1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 07:28:10 +00:00

style(python): modernize type annotations in protobuf.py

This commit is contained in:
matejcik 2024-06-12 14:11:53 +02:00
parent eaeb58fb25
commit 27fef37cc9

View File

@ -22,23 +22,25 @@ For de-serializing (loading) protobuf types, object with `Reader` interface is r
For serializing (dumping) protobuf types, object with `Writer` interface is required. For serializing (dumping) protobuf types, object with `Writer` interface is required.
""" """
from __future__ import annotations
import logging import logging
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from io import BytesIO from io import BytesIO
from itertools import zip_longest from itertools import zip_longest
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import typing as t
from typing_extensions import Protocol, TypeGuard import typing_extensions as tx
T = TypeVar("T", bound=type) T = t.TypeVar("T", bound=type)
MT = TypeVar("MT", bound="MessageType") MT = t.TypeVar("MT", bound="MessageType")
MAX_FIELD_SIZE = 1024 * 1024 # 1 MB MAX_FIELD_SIZE = 1024 * 1024 # 1 MB
class Reader(Protocol): class Reader(tx.Protocol):
def readinto(self, __buf: bytearray) -> int: def readinto(self, __buf: bytearray) -> int:
""" """
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read, Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
@ -47,7 +49,7 @@ class Reader(Protocol):
... ...
class Writer(Protocol): class Writer(tx.Protocol):
def write(self, __buf: bytes) -> int: def write(self, __buf: bytes) -> int:
""" """
Writes all bytes from `buffer`, or raises `EOFError` Writes all bytes from `buffer`, or raises `EOFError`
@ -60,7 +62,7 @@ _UVARINT_BUFFER = bytearray(1)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]: def safe_issubclass(value: t.Any, cls: T | tuple[T, ...]) -> tx.TypeGuard[T]:
return isinstance(value, type) and issubclass(value, cls) return isinstance(value, type) and issubclass(value, cls)
@ -189,15 +191,15 @@ class _MessageTypeMeta(type):
class MessageType(metaclass=_MessageTypeMeta): class MessageType(metaclass=_MessageTypeMeta):
MESSAGE_WIRE_TYPE: Optional[int] = None MESSAGE_WIRE_TYPE: t.ClassVar[int | None] = None
FIELDS: Dict[int, Field] = {} FIELDS: t.ClassVar[dict[int, Field]] = {}
@classmethod @classmethod
def get_field(cls, name: str) -> Optional[Field]: def get_field(cls, name: str) -> Field | None:
return next((f for f in cls.FIELDS.values() if f.name == name), None) return next((f for f in cls.FIELDS.values() if f.name == name), None)
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
if args: if args:
warnings.warn( warnings.warn(
"Positional arguments for MessageType are deprecated", "Positional arguments for MessageType are deprecated",
@ -220,7 +222,7 @@ class MessageType(metaclass=_MessageTypeMeta):
# set in args but not in kwargs # set in args but not in kwargs
setattr(self, field.name, val) setattr(self, field.name, val)
else: else:
default: Any default: t.Any
# not set at all, pick a default # not set at all, pick a default
if field.repeated: if field.repeated:
default = [] default = []
@ -235,7 +237,7 @@ class MessageType(metaclass=_MessageTypeMeta):
default = field.default default = field.default
setattr(self, field.name, default) setattr(self, field.name, default)
def __eq__(self, rhs: Any) -> bool: def __eq__(self, rhs: t.Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self) -> str: def __repr__(self) -> str:
@ -276,9 +278,7 @@ class CountingWriter:
return nwritten return nwritten
def get_field_type_object( def get_field_type_object(field: Field) -> type[MessageType] | type[IntEnum] | None:
field: Field,
) -> Optional[Union[Type[MessageType], Type[IntEnum]]]:
from . import messages from . import messages
field_type_object = getattr(messages, field.type, None) field_type_object = getattr(messages, field.type, None)
@ -287,7 +287,7 @@ def get_field_type_object(
return field_type_object return field_type_object
def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]: def decode_packed_array_field(field: Field, reader: Reader) -> list[t.Any]:
assert field.repeated, "Not decoding packed array into non-repeated field" assert field.repeated, "Not decoding packed array into non-repeated field"
length = load_uvarint(reader) length = load_uvarint(reader)
packed_reader = LimitedReader(reader, length) packed_reader = LimitedReader(reader, length)
@ -300,7 +300,7 @@ def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
return values return values
def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnum]: def decode_varint_field(field: Field, reader: Reader) -> int | bool | IntEnum:
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded" assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
value = load_uvarint(reader) value = load_uvarint(reader)
@ -336,7 +336,7 @@ def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnu
def decode_length_delimited_field( def decode_length_delimited_field(
field: Field, reader: Reader field: Field, reader: Reader
) -> Union[bytes, str, MessageType]: ) -> bytes | str | MessageType:
value = load_uvarint(reader) value = load_uvarint(reader)
if value > MAX_FIELD_SIZE: if value > MAX_FIELD_SIZE:
raise ValueError(f"Field {field.name} contents too large ({value} bytes)") raise ValueError(f"Field {field.name} contents too large ({value} bytes)")
@ -358,8 +358,8 @@ def decode_length_delimited_field(
raise TypeError # field type is unknown raise TypeError # field type is unknown
def load_message(reader: Reader, msg_type: Type[MT]) -> MT: def load_message(reader: Reader, msg_type: type[MT]) -> MT:
msg_dict: Dict[str, Any] = {} msg_dict: dict[str, t.Any] = {}
# pre-seed the dict # pre-seed the dict
for field in msg_type.FIELDS.values(): for field in msg_type.FIELDS.values():
if field.repeated: if field.repeated:
@ -500,8 +500,8 @@ def format_message(
pb: "MessageType", pb: "MessageType",
indent: int = 0, indent: int = 0,
sep: str = " " * 4, sep: str = " " * 4,
truncate_after: Optional[int] = 256, truncate_after: int | None = 256,
truncate_to: Optional[int] = 64, truncate_to: int | None = 64,
) -> str: ) -> str:
def mostly_printable(bytes: bytes) -> bool: def mostly_printable(bytes: bytes) -> bool:
if not bytes: if not bytes:
@ -509,7 +509,7 @@ def format_message(
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E) printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
return printable / len(bytes) > 0.8 return printable / len(bytes) > 0.8
def pformat(name: str, value: Any, indent: int) -> str: def pformat(name: str, value: t.Any, indent: int) -> str:
level = sep * indent level = sep * indent
leadin = sep * (indent + 1) leadin = sep * (indent + 1)
@ -568,7 +568,7 @@ def format_message(
) )
def value_to_proto(field: Field, value: Any) -> Any: def value_to_proto(field: Field, value: t.Any) -> t.Any:
field_type_object = get_field_type_object(field) field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType): if safe_issubclass(field_type_object, MessageType):
raise TypeError("value_to_proto only converts simple values") raise TypeError("value_to_proto only converts simple values")
@ -601,7 +601,7 @@ def value_to_proto(field: Field, value: Any) -> Any:
raise TypeError(f"can't convert {type(value)} value to bytes") raise TypeError(f"can't convert {type(value)} value to bytes")
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT: def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
params = {} params = {}
for field in message_type.FIELDS.values(): for field in message_type.FIELDS.values():
value = d.get(field.name) value = d.get(field.name)
@ -624,8 +624,8 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
return message_type(**params) return message_type(**params)
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]: def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> dict[str, t.Any]:
def convert_value(value: Any) -> Any: def convert_value(value: t.Any) -> t.Any:
if hexlify_bytes and isinstance(value, bytes): if hexlify_bytes and isinstance(value, bytes):
return value.hex() return value.hex()
elif isinstance(value, MessageType): elif isinstance(value, MessageType):