|
|
|
@ -24,6 +24,7 @@ For serializing (dumping) protobuf types, object with `Writer` interface is requ
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
from itertools import zip_longest
|
|
|
|
|
from typing import (
|
|
|
|
|
Any,
|
|
|
|
|
Callable,
|
|
|
|
@ -37,6 +38,7 @@ from typing import (
|
|
|
|
|
TypeVar,
|
|
|
|
|
Union,
|
|
|
|
|
)
|
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
|
|
from typing_extensions import Protocol
|
|
|
|
|
|
|
|
|
@ -198,11 +200,29 @@ class UnicodeType:
|
|
|
|
|
WIRE_TYPE = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MessageType:
|
|
|
|
|
class _MessageTypeMeta(type):
|
|
|
|
|
def __init__(cls, name, bases, d) -> None:
|
|
|
|
|
super().__init__(name, bases, d)
|
|
|
|
|
if name != "MessageType":
|
|
|
|
|
cls.__init__ = MessageType.__init__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MessageType(metaclass=_MessageTypeMeta):
|
|
|
|
|
WIRE_TYPE = 2
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_fields(cls) -> Dict[int, FieldInfo]:
|
|
|
|
|
"""Return a field descriptor.
|
|
|
|
|
|
|
|
|
|
The descriptor is a mapping:
|
|
|
|
|
field_id -> (field_name, field_type, default_value)
|
|
|
|
|
|
|
|
|
|
`default_value` can also be one of the special values:
|
|
|
|
|
* `FLAG_REQUIRED` indicates that the field value has no default and _must_ be
|
|
|
|
|
provided by caller/sender.
|
|
|
|
|
* `FLAG_REPEATED` indicates that the field is a list of `field_type` values. In
|
|
|
|
|
that case the default value is an empty list.
|
|
|
|
|
"""
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@ -212,10 +232,42 @@ class MessageType:
|
|
|
|
|
return ftype
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def __init__(self, **kwargs: Any) -> None:
|
|
|
|
|
for kw in kwargs:
|
|
|
|
|
setattr(self, kw, kwargs[kw])
|
|
|
|
|
self._fill_missing()
|
|
|
|
|
def __init__(self, *args, **kwargs: Any) -> None:
|
|
|
|
|
fields = self.get_fields()
|
|
|
|
|
if args:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"Positional arguments for MessageType are deprecated",
|
|
|
|
|
DeprecationWarning,
|
|
|
|
|
stacklevel=2,
|
|
|
|
|
)
|
|
|
|
|
# process fields one by one
|
|
|
|
|
NOT_PROVIDED = object()
|
|
|
|
|
for field, val in zip_longest(fields.values(), args, fillvalue=NOT_PROVIDED):
|
|
|
|
|
if field is NOT_PROVIDED:
|
|
|
|
|
raise TypeError("too many positional arguments")
|
|
|
|
|
fname, _, fdefault = field
|
|
|
|
|
if fname in kwargs and val is not NOT_PROVIDED:
|
|
|
|
|
# both *args and **kwargs specify the same thing
|
|
|
|
|
raise TypeError(f"got multiple values for argument '{fname}'")
|
|
|
|
|
elif fname in kwargs:
|
|
|
|
|
# set in kwargs but not in args
|
|
|
|
|
setattr(self, fname, kwargs[fname])
|
|
|
|
|
elif val is not NOT_PROVIDED:
|
|
|
|
|
# set in args but not in kwargs
|
|
|
|
|
setattr(self, fname, val)
|
|
|
|
|
else:
|
|
|
|
|
# not set at all, pick a default
|
|
|
|
|
if fdefault is FLAG_REPEATED:
|
|
|
|
|
fdefault = []
|
|
|
|
|
elif fdefault is FLAG_EXPERIMENTAL:
|
|
|
|
|
fdefault = None
|
|
|
|
|
elif fdefault is FLAG_REQUIRED:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
f"Value of required field '{fname}' must be provided in constructor",
|
|
|
|
|
DeprecationWarning,
|
|
|
|
|
stacklevel=2,
|
|
|
|
|
)
|
|
|
|
|
setattr(self, fname, fdefault)
|
|
|
|
|
|
|
|
|
|
def __eq__(self, rhs: Any) -> bool:
|
|
|
|
|
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
|
|
|
@ -237,17 +289,6 @@ class MessageType:
|
|
|
|
|
def __getitem__(self, key: str) -> Any:
|
|
|
|
|
return getattr(self, key)
|
|
|
|
|
|
|
|
|
|
def _fill_missing(self) -> None:
|
|
|
|
|
# fill missing fields
|
|
|
|
|
for fname, _, fdefault in self.get_fields().values():
|
|
|
|
|
if not hasattr(self, fname):
|
|
|
|
|
if fdefault is FLAG_REPEATED:
|
|
|
|
|
setattr(self, fname, [])
|
|
|
|
|
elif fdefault is FLAG_REQUIRED:
|
|
|
|
|
raise ValueError("value for required field is missing")
|
|
|
|
|
else:
|
|
|
|
|
setattr(self, fname, fdefault)
|
|
|
|
|
|
|
|
|
|
def ByteSize(self) -> int:
|
|
|
|
|
data = BytesIO()
|
|
|
|
|
dump_message(data, self)
|
|
|
|
@ -403,6 +444,8 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
|
|
|
|
fvalue = getattr(msg, fname, None)
|
|
|
|
|
if fvalue is None:
|
|
|
|
|
continue
|
|
|
|
|
if fvalue is FLAG_REQUIRED:
|
|
|
|
|
raise ValueError # required value was not provided
|
|
|
|
|
|
|
|
|
|
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
|
|
|
|
|
|
|
|
|