From 27fef37cc930a31120b46871033a6755182e374f Mon Sep 17 00:00:00 2001 From: matejcik Date: Wed, 12 Jun 2024 14:11:53 +0200 Subject: [PATCH] style(python): modernize type annotations in protobuf.py --- python/src/trezorlib/protobuf.py | 56 ++++++++++++++++---------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index f0d407191a..be4050616b 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -22,23 +22,25 @@ For de-serializing (loading) protobuf types, object with `Reader` interface is r For serializing (dumping) protobuf types, object with `Writer` interface is required. """ +from __future__ import annotations + import logging import warnings from dataclasses import dataclass from enum import IntEnum from io import BytesIO from itertools import zip_longest -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +import typing as t -from typing_extensions import Protocol, TypeGuard +import typing_extensions as tx -T = TypeVar("T", bound=type) -MT = TypeVar("MT", bound="MessageType") +T = t.TypeVar("T", bound=type) +MT = t.TypeVar("MT", bound="MessageType") MAX_FIELD_SIZE = 1024 * 1024 # 1 MB -class Reader(Protocol): +class Reader(tx.Protocol): def readinto(self, __buf: bytearray) -> int: """ Reads exactly `len(buffer)` bytes into `buffer`. Returns number of bytes read, @@ -47,7 +49,7 @@ class Reader(Protocol): ... -class Writer(Protocol): +class Writer(tx.Protocol): def write(self, __buf: bytes) -> int: """ Writes all bytes from `buffer`, or raises `EOFError` @@ -60,7 +62,7 @@ _UVARINT_BUFFER = bytearray(1) LOG = logging.getLogger(__name__) -def safe_issubclass(value: Any, cls: Union[T, Tuple[T, ...]]) -> TypeGuard[T]: +def safe_issubclass(value: t.Any, cls: T | tuple[T, ...]) -> tx.TypeGuard[T]: return isinstance(value, type) and issubclass(value, cls) @@ -189,15 +191,15 @@ class _MessageTypeMeta(type): class MessageType(metaclass=_MessageTypeMeta): - MESSAGE_WIRE_TYPE: Optional[int] = None + MESSAGE_WIRE_TYPE: t.ClassVar[int | None] = None - FIELDS: Dict[int, Field] = {} + FIELDS: t.ClassVar[dict[int, Field]] = {} @classmethod - def get_field(cls, name: str) -> Optional[Field]: + def get_field(cls, name: str) -> Field | None: return next((f for f in cls.FIELDS.values() if f.name == name), None) - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: if args: warnings.warn( "Positional arguments for MessageType are deprecated", @@ -220,7 +222,7 @@ class MessageType(metaclass=_MessageTypeMeta): # set in args but not in kwargs setattr(self, field.name, val) else: - default: Any + default: t.Any # not set at all, pick a default if field.repeated: default = [] @@ -235,7 +237,7 @@ class MessageType(metaclass=_MessageTypeMeta): default = field.default setattr(self, field.name, default) - def __eq__(self, rhs: Any) -> bool: + def __eq__(self, rhs: t.Any) -> bool: return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ def __repr__(self) -> str: @@ -276,9 +278,7 @@ class CountingWriter: return nwritten -def get_field_type_object( - field: Field, -) -> Optional[Union[Type[MessageType], Type[IntEnum]]]: +def get_field_type_object(field: Field) -> type[MessageType] | type[IntEnum] | None: from . import messages field_type_object = getattr(messages, field.type, None) @@ -287,7 +287,7 @@ def get_field_type_object( return field_type_object -def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]: +def decode_packed_array_field(field: Field, reader: Reader) -> list[t.Any]: assert field.repeated, "Not decoding packed array into non-repeated field" length = load_uvarint(reader) packed_reader = LimitedReader(reader, length) @@ -300,7 +300,7 @@ def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]: return values -def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnum]: +def decode_varint_field(field: Field, reader: Reader) -> int | bool | IntEnum: assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded" value = load_uvarint(reader) @@ -336,7 +336,7 @@ def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnu def decode_length_delimited_field( field: Field, reader: Reader -) -> Union[bytes, str, MessageType]: +) -> bytes | str | MessageType: value = load_uvarint(reader) if value > MAX_FIELD_SIZE: raise ValueError(f"Field {field.name} contents too large ({value} bytes)") @@ -358,8 +358,8 @@ def decode_length_delimited_field( raise TypeError # field type is unknown -def load_message(reader: Reader, msg_type: Type[MT]) -> MT: - msg_dict: Dict[str, Any] = {} +def load_message(reader: Reader, msg_type: type[MT]) -> MT: + msg_dict: dict[str, t.Any] = {} # pre-seed the dict for field in msg_type.FIELDS.values(): if field.repeated: @@ -500,8 +500,8 @@ def format_message( pb: "MessageType", indent: int = 0, sep: str = " " * 4, - truncate_after: Optional[int] = 256, - truncate_to: Optional[int] = 64, + truncate_after: int | None = 256, + truncate_to: int | None = 64, ) -> str: def mostly_printable(bytes: bytes) -> bool: if not bytes: @@ -509,7 +509,7 @@ def format_message( printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E) return printable / len(bytes) > 0.8 - def pformat(name: str, value: Any, indent: int) -> str: + def pformat(name: str, value: t.Any, indent: int) -> str: level = sep * indent leadin = sep * (indent + 1) @@ -568,7 +568,7 @@ def format_message( ) -def value_to_proto(field: Field, value: Any) -> Any: +def value_to_proto(field: Field, value: t.Any) -> t.Any: field_type_object = get_field_type_object(field) if safe_issubclass(field_type_object, MessageType): raise TypeError("value_to_proto only converts simple values") @@ -601,7 +601,7 @@ def value_to_proto(field: Field, value: Any) -> Any: raise TypeError(f"can't convert {type(value)} value to bytes") -def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT: +def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT: params = {} for field in message_type.FIELDS.values(): value = d.get(field.name) @@ -624,8 +624,8 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT: return message_type(**params) -def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> Dict[str, Any]: - def convert_value(value: Any) -> Any: +def to_dict(msg: "MessageType", hexlify_bytes: bool = True) -> dict[str, t.Any]: + def convert_value(value: t.Any) -> t.Any: if hexlify_bytes and isinstance(value, bytes): return value.hex() elif isinstance(value, MessageType):