1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-21 23:18:13 +00:00

style(python): modernize type annotations in protobuf.py

This commit is contained in:
matejcik 2024-06-12 14:11:53 +02:00
parent eaeb58fb25
commit 27fef37cc9

View File

@ -22,23 +22,25 @@ For de-serializing (loading) protobuf types, object with `Reader` interface is r
For serializing (dumping) protobuf types, object with `Writer` interface is required.
"""
from __future__ import annotations
import logging
import warnings
from dataclasses import dataclass
from enum import IntEnum
from io import BytesIO
from itertools import zip_longest
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import typing as t
from typing_extensions import Protocol, TypeGuard
import typing_extensions as tx
T = TypeVar("T", bound=type)
MT = TypeVar("MT", bound="MessageType")
T = t.TypeVar("T", bound=type)
MT = t.TypeVar("MT", bound="MessageType")
MAX_FIELD_SIZE = 1024 * 1024 # 1 MB
class Reader(Protocol):
class Reader(tx.Protocol):
def readinto(self, __buf: bytearray) -> int:
"""
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
@ -47,7 +49,7 @@ class Reader(Protocol):
...
class Writer(Protocol):
class Writer(tx.Protocol):
def write(self, __buf: bytes) -> int:
"""
Writes all bytes from `buffer`, or raises `EOFError`
@ -60,7 +62,7 @@ _UVARINT_BUFFER = bytearray(1)
LOG = logging.getLogger(__name__)
def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]:
def safe_issubclass(value: t.Any, cls: T | tuple[T, ...]) -> tx.TypeGuard[T]:
return isinstance(value, type) and issubclass(value, cls)
@ -189,15 +191,15 @@ class _MessageTypeMeta(type):
class MessageType(metaclass=_MessageTypeMeta):
MESSAGE_WIRE_TYPE: Optional[int] = None
MESSAGE_WIRE_TYPE: t.ClassVar[int | None] = None
FIELDS: Dict[int, Field] = {}
FIELDS: t.ClassVar[dict[int, Field]] = {}
@classmethod
def get_field(cls, name: str) -> Optional[Field]:
def get_field(cls, name: str) -> Field | None:
return next((f for f in cls.FIELDS.values() if f.name == name), None)
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
if args:
warnings.warn(
"Positional arguments for MessageType are deprecated",
@ -220,7 +222,7 @@ class MessageType(metaclass=_MessageTypeMeta):
# set in args but not in kwargs
setattr(self, field.name, val)
else:
default: Any
default: t.Any
# not set at all, pick a default
if field.repeated:
default = []
@ -235,7 +237,7 @@ class MessageType(metaclass=_MessageTypeMeta):
default = field.default
setattr(self, field.name, default)
def __eq__(self, rhs: Any) -> bool:
def __eq__(self, rhs: t.Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self) -> str:
@ -276,9 +278,7 @@ class CountingWriter:
return nwritten
def get_field_type_object(
field: Field,
) -> Optional[Union[Type[MessageType], Type[IntEnum]]]:
def get_field_type_object(field: Field) -> type[MessageType] | type[IntEnum] | None:
from . import messages
field_type_object = getattr(messages, field.type, None)
@ -287,7 +287,7 @@ def get_field_type_object(
return field_type_object
def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
def decode_packed_array_field(field: Field, reader: Reader) -> list[t.Any]:
assert field.repeated, "Not decoding packed array into non-repeated field"
length = load_uvarint(reader)
packed_reader = LimitedReader(reader, length)
@ -300,7 +300,7 @@ def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
return values
def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnum]:
def decode_varint_field(field: Field, reader: Reader) -> int | bool | IntEnum:
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
value = load_uvarint(reader)
@ -336,7 +336,7 @@ def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnu
def decode_length_delimited_field(
field: Field, reader: Reader
) -> Union[bytes, str, MessageType]:
) -> bytes | str | MessageType:
value = load_uvarint(reader)
if value > MAX_FIELD_SIZE:
raise ValueError(f"Field {field.name} contents too large ({value} bytes)")
@ -358,8 +358,8 @@ def decode_length_delimited_field(
raise TypeError # field type is unknown
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
msg_dict: Dict[str, Any] = {}
def load_message(reader: Reader, msg_type: type[MT]) -> MT:
msg_dict: dict[str, t.Any] = {}
# pre-seed the dict
for field in msg_type.FIELDS.values():
if field.repeated:
@ -500,8 +500,8 @@ def format_message(
pb: "MessageType",
indent: int = 0,
sep: str = " " * 4,
truncate_after: Optional[int] = 256,
truncate_to: Optional[int] = 64,
truncate_after: int | None = 256,
truncate_to: int | None = 64,
) -> str:
def mostly_printable(bytes: bytes) -> bool:
if not bytes:
@ -509,7 +509,7 @@ def format_message(
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
return printable / len(bytes) > 0.8
def pformat(name: str, value: Any, indent: int) -> str:
def pformat(name: str, value: t.Any, indent: int) -> str:
level = sep * indent
leadin = sep * (indent + 1)
@ -568,7 +568,7 @@ def format_message(
)
def value_to_proto(field: Field, value: Any) -> Any:
def value_to_proto(field: Field, value: t.Any) -> t.Any:
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
raise TypeError("value_to_proto only converts simple values")
@ -601,7 +601,7 @@ def value_to_proto(field: Field, value: Any) -> Any:
raise TypeError(f"can't convert {type(value)} value to bytes")
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
params = {}
for field in message_type.FIELDS.values():
value = d.get(field.name)
@ -624,8 +624,8 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
return message_type(**params)
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]:
def convert_value(value: Any) -> Any:
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> dict[str, t.Any]:
def convert_value(value: t.Any) -> t.Any:
if hexlify_bytes and isinstance(value, bytes):
return value.hex()
elif isinstance(value, MessageType):