mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 07:28:10 +00:00
style(python): modernize type annotations in protobuf.py
This commit is contained in:
parent
eaeb58fb25
commit
27fef37cc9
@ -22,23 +22,25 @@ For de-serializing (loading) protobuf types, object with `Reader` interface is r
|
|||||||
For serializing (dumping) protobuf types, object with `Writer` interface is required.
|
For serializing (dumping) protobuf types, object with `Writer` interface is required.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
import typing as t
|
||||||
|
|
||||||
from typing_extensions import Protocol, TypeGuard
|
import typing_extensions as tx
|
||||||
|
|
||||||
T = TypeVar("T", bound=type)
|
T = t.TypeVar("T", bound=type)
|
||||||
MT = TypeVar("MT", bound="MessageType")
|
MT = t.TypeVar("MT", bound="MessageType")
|
||||||
|
|
||||||
MAX_FIELD_SIZE = 1024 * 1024 # 1 MB
|
MAX_FIELD_SIZE = 1024 * 1024 # 1 MB
|
||||||
|
|
||||||
|
|
||||||
class Reader(Protocol):
|
class Reader(tx.Protocol):
|
||||||
def readinto(self, __buf: bytearray) -> int:
|
def readinto(self, __buf: bytearray) -> int:
|
||||||
"""
|
"""
|
||||||
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
|
Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read,
|
||||||
@ -47,7 +49,7 @@ class Reader(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class Writer(Protocol):
|
class Writer(tx.Protocol):
|
||||||
def write(self, __buf: bytes) -> int:
|
def write(self, __buf: bytes) -> int:
|
||||||
"""
|
"""
|
||||||
Writes all bytes from `buffer`, or raises `EOFError`
|
Writes all bytes from `buffer`, or raises `EOFError`
|
||||||
@ -60,7 +62,7 @@ _UVARINT_BUFFER = bytearray(1)
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]:
|
def safe_issubclass(value: t.Any, cls: T | tuple[T, ...]) -> tx.TypeGuard[T]:
|
||||||
return isinstance(value, type) and issubclass(value, cls)
|
return isinstance(value, type) and issubclass(value, cls)
|
||||||
|
|
||||||
|
|
||||||
@ -189,15 +191,15 @@ class _MessageTypeMeta(type):
|
|||||||
|
|
||||||
|
|
||||||
class MessageType(metaclass=_MessageTypeMeta):
|
class MessageType(metaclass=_MessageTypeMeta):
|
||||||
MESSAGE_WIRE_TYPE: Optional[int] = None
|
MESSAGE_WIRE_TYPE: t.ClassVar[int | None] = None
|
||||||
|
|
||||||
FIELDS: Dict[int, Field] = {}
|
FIELDS: t.ClassVar[dict[int, Field]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_field(cls, name: str) -> Optional[Field]:
|
def get_field(cls, name: str) -> Field | None:
|
||||||
return next((f for f in cls.FIELDS.values() if f.name == name), None)
|
return next((f for f in cls.FIELDS.values() if f.name == name), None)
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
|
||||||
if args:
|
if args:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Positional arguments for MessageType are deprecated",
|
"Positional arguments for MessageType are deprecated",
|
||||||
@ -220,7 +222,7 @@ class MessageType(metaclass=_MessageTypeMeta):
|
|||||||
# set in args but not in kwargs
|
# set in args but not in kwargs
|
||||||
setattr(self, field.name, val)
|
setattr(self, field.name, val)
|
||||||
else:
|
else:
|
||||||
default: Any
|
default: t.Any
|
||||||
# not set at all, pick a default
|
# not set at all, pick a default
|
||||||
if field.repeated:
|
if field.repeated:
|
||||||
default = []
|
default = []
|
||||||
@ -235,7 +237,7 @@ class MessageType(metaclass=_MessageTypeMeta):
|
|||||||
default = field.default
|
default = field.default
|
||||||
setattr(self, field.name, default)
|
setattr(self, field.name, default)
|
||||||
|
|
||||||
def __eq__(self, rhs: Any) -> bool:
|
def __eq__(self, rhs: t.Any) -> bool:
|
||||||
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -276,9 +278,7 @@ class CountingWriter:
|
|||||||
return nwritten
|
return nwritten
|
||||||
|
|
||||||
|
|
||||||
def get_field_type_object(
|
def get_field_type_object(field: Field) -> type[MessageType] | type[IntEnum] | None:
|
||||||
field: Field,
|
|
||||||
) -> Optional[Union[Type[MessageType], Type[IntEnum]]]:
|
|
||||||
from . import messages
|
from . import messages
|
||||||
|
|
||||||
field_type_object = getattr(messages, field.type, None)
|
field_type_object = getattr(messages, field.type, None)
|
||||||
@ -287,7 +287,7 @@ def get_field_type_object(
|
|||||||
return field_type_object
|
return field_type_object
|
||||||
|
|
||||||
|
|
||||||
def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
|
def decode_packed_array_field(field: Field, reader: Reader) -> list[t.Any]:
|
||||||
assert field.repeated, "Not decoding packed array into non-repeated field"
|
assert field.repeated, "Not decoding packed array into non-repeated field"
|
||||||
length = load_uvarint(reader)
|
length = load_uvarint(reader)
|
||||||
packed_reader = LimitedReader(reader, length)
|
packed_reader = LimitedReader(reader, length)
|
||||||
@ -300,7 +300,7 @@ def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnum]:
|
def decode_varint_field(field: Field, reader: Reader) -> int | bool | IntEnum:
|
||||||
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
|
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
|
||||||
value = load_uvarint(reader)
|
value = load_uvarint(reader)
|
||||||
|
|
||||||
@ -336,7 +336,7 @@ def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnu
|
|||||||
|
|
||||||
def decode_length_delimited_field(
|
def decode_length_delimited_field(
|
||||||
field: Field, reader: Reader
|
field: Field, reader: Reader
|
||||||
) -> Union[bytes, str, MessageType]:
|
) -> bytes | str | MessageType:
|
||||||
value = load_uvarint(reader)
|
value = load_uvarint(reader)
|
||||||
if value > MAX_FIELD_SIZE:
|
if value > MAX_FIELD_SIZE:
|
||||||
raise ValueError(f"Field {field.name} contents too large ({value} bytes)")
|
raise ValueError(f"Field {field.name} contents too large ({value} bytes)")
|
||||||
@ -358,8 +358,8 @@ def decode_length_delimited_field(
|
|||||||
raise TypeError # field type is unknown
|
raise TypeError # field type is unknown
|
||||||
|
|
||||||
|
|
||||||
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
def load_message(reader: Reader, msg_type: type[MT]) -> MT:
|
||||||
msg_dict: Dict[str, Any] = {}
|
msg_dict: dict[str, t.Any] = {}
|
||||||
# pre-seed the dict
|
# pre-seed the dict
|
||||||
for field in msg_type.FIELDS.values():
|
for field in msg_type.FIELDS.values():
|
||||||
if field.repeated:
|
if field.repeated:
|
||||||
@ -500,8 +500,8 @@ def format_message(
|
|||||||
pb: "MessageType",
|
pb: "MessageType",
|
||||||
indent: int = 0,
|
indent: int = 0,
|
||||||
sep: str = " " * 4,
|
sep: str = " " * 4,
|
||||||
truncate_after: Optional[int] = 256,
|
truncate_after: int | None = 256,
|
||||||
truncate_to: Optional[int] = 64,
|
truncate_to: int | None = 64,
|
||||||
) -> str:
|
) -> str:
|
||||||
def mostly_printable(bytes: bytes) -> bool:
|
def mostly_printable(bytes: bytes) -> bool:
|
||||||
if not bytes:
|
if not bytes:
|
||||||
@ -509,7 +509,7 @@ def format_message(
|
|||||||
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
|
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
|
||||||
return printable / len(bytes) > 0.8
|
return printable / len(bytes) > 0.8
|
||||||
|
|
||||||
def pformat(name: str, value: Any, indent: int) -> str:
|
def pformat(name: str, value: t.Any, indent: int) -> str:
|
||||||
level = sep * indent
|
level = sep * indent
|
||||||
leadin = sep * (indent + 1)
|
leadin = sep * (indent + 1)
|
||||||
|
|
||||||
@ -568,7 +568,7 @@ def format_message(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def value_to_proto(field: Field, value: Any) -> Any:
|
def value_to_proto(field: Field, value: t.Any) -> t.Any:
|
||||||
field_type_object = get_field_type_object(field)
|
field_type_object = get_field_type_object(field)
|
||||||
if safe_issubclass(field_type_object, MessageType):
|
if safe_issubclass(field_type_object, MessageType):
|
||||||
raise TypeError("value_to_proto only converts simple values")
|
raise TypeError("value_to_proto only converts simple values")
|
||||||
@ -601,7 +601,7 @@ def value_to_proto(field: Field, value: Any) -> Any:
|
|||||||
raise TypeError(f"can't convert {type(value)} value to bytes")
|
raise TypeError(f"can't convert {type(value)} value to bytes")
|
||||||
|
|
||||||
|
|
||||||
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
|
||||||
params = {}
|
params = {}
|
||||||
for field in message_type.FIELDS.values():
|
for field in message_type.FIELDS.values():
|
||||||
value = d.get(field.name)
|
value = d.get(field.name)
|
||||||
@ -624,8 +624,8 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
|||||||
return message_type(**params)
|
return message_type(**params)
|
||||||
|
|
||||||
|
|
||||||
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]:
|
def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> dict[str, t.Any]:
|
||||||
def convert_value(value: Any) -> Any:
|
def convert_value(value: t.Any) -> t.Any:
|
||||||
if hexlify_bytes and isinstance(value, bytes):
|
if hexlify_bytes and isinstance(value, bytes):
|
||||||
return value.hex()
|
return value.hex()
|
||||||
elif isinstance(value, MessageType):
|
elif isinstance(value, MessageType):
|
||||||
|
Loading…
Reference in New Issue
Block a user