mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-21 23:18:13 +00:00
style(python): modernize type annotations in protobuf.py
This commit is contained in:
parent
eaeb58fb25
commit
27fef37cc9
@ -22,23 +22,25 @@ For de-serializing (loading) protobuf types, object with `Reader` interface is r
|
||||
For serializing (dumping) protobuf types, object with `Writer` interface is required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from io import BytesIO
|
||||
from itertools import zip_longest
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
import typing as t
|
||||
|
||||
from typing_extensions import Protocol, TypeGuard
|
||||
import typing_extensions as tx
|
||||
|
||||
T = TypeVar("T", bound=type)
|
||||
MT = TypeVar("MT", bound="MessageType")
|
||||
T = t.TypeVar("T", bound=type)
|
||||
MT = t.TypeVar("MT", bound="MessageType")
|
||||
|
||||
MAX_FIELD_SIZE = 1024 * 1024 # 1 MB
|
||||
|
||||
|
||||
class Reader(Protocol):
|
||||
class Reader(tx.Protocol):
|
||||
def readinto(self, __buf: bytearray) -> int:
|
||||
"""
|
||||
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
|
||||
@ -47,7 +49,7 @@ class Reader(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class Writer(Protocol):
|
||||
class Writer(tx.Protocol):
|
||||
def write(self, __buf: bytes) -> int:
|
||||
"""
|
||||
Writes all bytes from `buffer`, or raises `EOFError`
|
||||
@ -60,7 +62,7 @@ _UVARINT_BUFFER = bytearray(1)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]:
|
||||
def safe_issubclass(value: t.Any, cls: T | tuple[T, ...]) -> tx.TypeGuard[T]:
|
||||
return isinstance(value, type) and issubclass(value, cls)
|
||||
|
||||
|
||||
@ -189,15 +191,15 @@ class _MessageTypeMeta(type):
|
||||
|
||||
|
||||
class MessageType(metaclass=_MessageTypeMeta):
|
||||
MESSAGE_WIRE_TYPE: Optional[int] = None
|
||||
MESSAGE_WIRE_TYPE: t.ClassVar[int | None] = None
|
||||
|
||||
FIELDS: Dict[int, Field] = {}
|
||||
FIELDS: t.ClassVar[dict[int, Field]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_field(cls, name: str) -> Optional[Field]:
|
||||
def get_field(cls, name: str) -> Field | None:
|
||||
return next((f for f in cls.FIELDS.values() if f.name == name), None)
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
|
||||
if args:
|
||||
warnings.warn(
|
||||
"Positional arguments for MessageType are deprecated",
|
||||
@ -220,7 +222,7 @@ class MessageType(metaclass=_MessageTypeMeta):
|
||||
# set in args but not in kwargs
|
||||
setattr(self, field.name, val)
|
||||
else:
|
||||
default: Any
|
||||
default: t.Any
|
||||
# not set at all, pick a default
|
||||
if field.repeated:
|
||||
default = []
|
||||
@ -235,7 +237,7 @@ class MessageType(metaclass=_MessageTypeMeta):
|
||||
default = field.default
|
||||
setattr(self, field.name, default)
|
||||
|
||||
def __eq__(self, rhs: Any) -> bool:
|
||||
def __eq__(self, rhs: t.Any) -> bool:
|
||||
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -276,9 +278,7 @@ class CountingWriter:
|
||||
return nwritten
|
||||
|
||||
|
||||
def get_field_type_object(
|
||||
field: Field,
|
||||
) -> Optional[Union[Type[MessageType], Type[IntEnum]]]:
|
||||
def get_field_type_object(field: Field) -> type[MessageType] | type[IntEnum] | None:
|
||||
from . import messages
|
||||
|
||||
field_type_object = getattr(messages, field.type, None)
|
||||
@ -287,7 +287,7 @@ def get_field_type_object(
|
||||
return field_type_object
|
||||
|
||||
|
||||
def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
|
||||
def decode_packed_array_field(field: Field, reader: Reader) -> list[t.Any]:
|
||||
assert field.repeated, "Not decoding packed array into non-repeated field"
|
||||
length = load_uvarint(reader)
|
||||
packed_reader = LimitedReader(reader, length)
|
||||
@ -300,7 +300,7 @@ def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
|
||||
return values
|
||||
|
||||
|
||||
def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnum]:
|
||||
def decode_varint_field(field: Field, reader: Reader) -> int | bool | IntEnum:
|
||||
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
|
||||
value = load_uvarint(reader)
|
||||
|
||||
@ -336,7 +336,7 @@ def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnu
|
||||
|
||||
def decode_length_delimited_field(
|
||||
field: Field, reader: Reader
|
||||
) -> Union[bytes, str, MessageType]:
|
||||
) -> bytes | str | MessageType:
|
||||
value = load_uvarint(reader)
|
||||
if value > MAX_FIELD_SIZE:
|
||||
raise ValueError(f"Field {field.name} contents too large ({value} bytes)")
|
||||
@ -358,8 +358,8 @@ def decode_length_delimited_field(
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
|
||||
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
msg_dict: Dict[str, Any] = {}
|
||||
def load_message(reader: Reader, msg_type: type[MT]) -> MT:
|
||||
msg_dict: dict[str, t.Any] = {}
|
||||
# pre-seed the dict
|
||||
for field in msg_type.FIELDS.values():
|
||||
if field.repeated:
|
||||
@ -500,8 +500,8 @@ def format_message(
|
||||
pb: "MessageType",
|
||||
indent: int = 0,
|
||||
sep: str = " " * 4,
|
||||
truncate_after: Optional[int] = 256,
|
||||
truncate_to: Optional[int] = 64,
|
||||
truncate_after: int | None = 256,
|
||||
truncate_to: int | None = 64,
|
||||
) -> str:
|
||||
def mostly_printable(bytes: bytes) -> bool:
|
||||
if not bytes:
|
||||
@ -509,7 +509,7 @@ def format_message(
|
||||
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
|
||||
return printable / len(bytes) > 0.8
|
||||
|
||||
def pformat(name: str, value: Any, indent: int) -> str:
|
||||
def pformat(name: str, value: t.Any, indent: int) -> str:
|
||||
level = sep * indent
|
||||
leadin = sep * (indent + 1)
|
||||
|
||||
@ -568,7 +568,7 @@ def format_message(
|
||||
)
|
||||
|
||||
|
||||
def value_to_proto(field: Field, value: Any) -> Any:
|
||||
def value_to_proto(field: Field, value: t.Any) -> t.Any:
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
raise TypeError("value_to_proto only converts simple values")
|
||||
@ -601,7 +601,7 @@ def value_to_proto(field: Field, value: Any) -> Any:
|
||||
raise TypeError(f"can't convert {type(value)} value to bytes")
|
||||
|
||||
|
||||
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
||||
def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
|
||||
params = {}
|
||||
for field in message_type.FIELDS.values():
|
||||
value = d.get(field.name)
|
||||
@ -624,8 +624,8 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
||||
return message_type(**params)
|
||||
|
||||
|
||||
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]:
|
||||
def convert_value(value: Any) -> Any:
|
||||
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> dict[str, t.Any]:
|
||||
def convert_value(value: t.Any) -> t.Any:
|
||||
if hexlify_bytes and isinstance(value, bytes):
|
||||
return value.hex()
|
||||
elif isinstance(value, MessageType):
|
||||
|
Loading…
Reference in New Issue
Block a user