|
|
|
@ -10,13 +10,12 @@ import shutil
|
|
|
|
|
import subprocess
|
|
|
|
|
import sys
|
|
|
|
|
import tempfile
|
|
|
|
|
from collections import namedtuple
|
|
|
|
|
from collections import namedtuple, defaultdict
|
|
|
|
|
|
|
|
|
|
import attr
|
|
|
|
|
|
|
|
|
|
from google.protobuf import descriptor_pb2
|
|
|
|
|
|
|
|
|
|
ProtoField = namedtuple(
|
|
|
|
|
"ProtoField", "name, number, proto_type, py_type, repeated, required, orig"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
AUTO_HEADER = "# Automatically generated by pb2py\n"
|
|
|
|
|
|
|
|
|
@ -24,7 +23,7 @@ AUTO_HEADER = "# Automatically generated by pb2py\n"
|
|
|
|
|
FIELD_TYPES = {
|
|
|
|
|
descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: ('p.UVarintType', 'int'),
|
|
|
|
|
descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: ('p.UVarintType', 'int'),
|
|
|
|
|
descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: ('p.UVarintType', 'int'),
|
|
|
|
|
# descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: ('p.UVarintType', 'int'),
|
|
|
|
|
descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: ('p.SVarintType', 'int'),
|
|
|
|
|
descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: ('p.SVarintType', 'int'),
|
|
|
|
|
descriptor_pb2.FieldDescriptorProto.TYPE_STRING: ('p.UnicodeType', 'str'),
|
|
|
|
@ -42,6 +41,53 @@ PROTOC_PREFIX = os.path.dirname(os.path.dirname(PROTOC))
|
|
|
|
|
PROTOC_INCLUDE = os.path.join(PROTOC_PREFIX, "include")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@attr.s
|
|
|
|
|
class ProtoField:
|
|
|
|
|
name = attr.ib()
|
|
|
|
|
number = attr.ib()
|
|
|
|
|
orig = attr.ib()
|
|
|
|
|
repeated = attr.ib()
|
|
|
|
|
required = attr.ib()
|
|
|
|
|
type_name = attr.ib()
|
|
|
|
|
proto_type = attr.ib()
|
|
|
|
|
py_type = attr.ib()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_field(cls, descriptor, field):
|
|
|
|
|
repeated = field.label == field.LABEL_REPEATED
|
|
|
|
|
required = field.label == field.LABEL_REQUIRED
|
|
|
|
|
# ignore package path
|
|
|
|
|
type_name = field.type_name.rsplit(".")[-1]
|
|
|
|
|
|
|
|
|
|
if field.type == field.TYPE_MESSAGE:
|
|
|
|
|
proto_type = py_type = type_name
|
|
|
|
|
elif field.type == field.TYPE_ENUM:
|
|
|
|
|
valuestr = ", ".join(str(v) for v in descriptor.enum_types[type_name])
|
|
|
|
|
proto_type = 'p.EnumType("{}", ({}))'.format(type_name, valuestr)
|
|
|
|
|
py_type = "EnumType" + type_name
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
proto_type, py_type = FIELD_TYPES[field.type]
|
|
|
|
|
except KeyError:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Unknown field type {} for field {}".format(field.type, field.name)
|
|
|
|
|
) from None
|
|
|
|
|
|
|
|
|
|
if repeated:
|
|
|
|
|
py_type = "List[{}]".format(py_type)
|
|
|
|
|
|
|
|
|
|
return cls(
|
|
|
|
|
name=field.name,
|
|
|
|
|
number=field.number,
|
|
|
|
|
orig=field,
|
|
|
|
|
repeated=repeated,
|
|
|
|
|
required=required,
|
|
|
|
|
type_name=type_name,
|
|
|
|
|
proto_type=proto_type,
|
|
|
|
|
py_type=py_type,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def protoc(files, additional_includes=()):
|
|
|
|
|
"""Compile code with protoc and return the data."""
|
|
|
|
|
include_dirs = set()
|
|
|
|
@ -113,6 +159,7 @@ class Descriptor:
|
|
|
|
|
# find messages and enums
|
|
|
|
|
self.messages = []
|
|
|
|
|
self.enums = []
|
|
|
|
|
self.enum_types = defaultdict(set)
|
|
|
|
|
for file in self.files:
|
|
|
|
|
self.messages += file.message_type
|
|
|
|
|
self.enums += file.enum_type
|
|
|
|
@ -136,9 +183,7 @@ class Descriptor:
|
|
|
|
|
def find_message_types(self, message_type):
|
|
|
|
|
message_types = {}
|
|
|
|
|
try:
|
|
|
|
|
message_type_enum = next(
|
|
|
|
|
enum for enum in self.enums if enum.name == message_type
|
|
|
|
|
)
|
|
|
|
|
message_type_enum = next(e for e in self.enums if e.name == message_type)
|
|
|
|
|
for value in message_type_enum.value:
|
|
|
|
|
name = strip_leader(value.name, message_type)
|
|
|
|
|
message_types[name] = value.number
|
|
|
|
@ -151,38 +196,10 @@ class Descriptor:
|
|
|
|
|
|
|
|
|
|
return message_types
|
|
|
|
|
|
|
|
|
|
def parse_field(self, field):
|
|
|
|
|
repeated = field.label == field.LABEL_REPEATED
|
|
|
|
|
required = field.label == field.LABEL_REQUIRED
|
|
|
|
|
if field.type == field.TYPE_MESSAGE:
|
|
|
|
|
# ignore package path
|
|
|
|
|
type_name = field.type_name.rsplit(".")[-1]
|
|
|
|
|
proto_type = py_type = type_name
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
proto_type, py_type = FIELD_TYPES[field.type]
|
|
|
|
|
except KeyError:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Unknown field type {} for field {}".format(field.type, field.name)
|
|
|
|
|
) from None
|
|
|
|
|
|
|
|
|
|
if repeated:
|
|
|
|
|
py_type = "List[{}]".format(py_type)
|
|
|
|
|
|
|
|
|
|
return ProtoField(
|
|
|
|
|
name=field.name,
|
|
|
|
|
number=field.number,
|
|
|
|
|
proto_type=proto_type,
|
|
|
|
|
py_type=py_type,
|
|
|
|
|
repeated=repeated,
|
|
|
|
|
required=required,
|
|
|
|
|
orig=field,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def create_message_import(self, name):
|
|
|
|
|
return "from .{0} import {0}".format(name)
|
|
|
|
|
|
|
|
|
|
def process_message_imports(self, fields):
|
|
|
|
|
def process_subtype_imports(self, fields):
|
|
|
|
|
imports = set(
|
|
|
|
|
field.proto_type
|
|
|
|
|
for field in fields
|
|
|
|
@ -251,16 +268,26 @@ class Descriptor:
|
|
|
|
|
# "from .. import protobuf as p"
|
|
|
|
|
yield self.protobuf_import + " as p"
|
|
|
|
|
|
|
|
|
|
fields = [self.parse_field(field) for field in message.field]
|
|
|
|
|
fields = [ProtoField.from_field(self, field) for field in message.field]
|
|
|
|
|
|
|
|
|
|
yield from self.process_message_imports(fields)
|
|
|
|
|
yield from self.process_subtype_imports(fields)
|
|
|
|
|
|
|
|
|
|
yield ""
|
|
|
|
|
yield "if __debug__:"
|
|
|
|
|
yield " try:"
|
|
|
|
|
yield " from typing import Dict, List, Optional"
|
|
|
|
|
yield " from typing_extensions import Literal # noqa: F401"
|
|
|
|
|
|
|
|
|
|
all_enums = [field for field in fields if field.type_name in self.enum_types]
|
|
|
|
|
for field in all_enums:
|
|
|
|
|
allowed_values = self.enum_types[field.type_name]
|
|
|
|
|
valuestr = ", ".join(str(v) for v in sorted(allowed_values))
|
|
|
|
|
yield " {} = Literal[{}]".format(field.py_type, valuestr)
|
|
|
|
|
|
|
|
|
|
yield " except ImportError:"
|
|
|
|
|
yield " Dict, List, Optional = None, None, None # type: ignore"
|
|
|
|
|
for field in all_enums:
|
|
|
|
|
yield " {} = None # type: ignore".format(field.py_type)
|
|
|
|
|
|
|
|
|
|
yield ""
|
|
|
|
|
yield ""
|
|
|
|
@ -294,6 +321,7 @@ class Descriptor:
|
|
|
|
|
enum_prefix, _ = enum_prefix.rsplit("Type", 1)
|
|
|
|
|
name = strip_leader(name, enum_prefix)
|
|
|
|
|
|
|
|
|
|
self.enum_types[enum.name].add(value.number)
|
|
|
|
|
yield "{} = {}".format(name, value.number)
|
|
|
|
|
|
|
|
|
|
def process_messages(self, messages):
|
|
|
|
@ -325,8 +353,8 @@ class Descriptor:
|
|
|
|
|
|
|
|
|
|
def write_classes(self, out_dir, init_py=True):
|
|
|
|
|
self.out_dir = out_dir
|
|
|
|
|
self.process_messages(self.messages)
|
|
|
|
|
self.process_enums(self.enums)
|
|
|
|
|
self.process_messages(self.messages)
|
|
|
|
|
if init_py:
|
|
|
|
|
self.write_init_py()
|
|
|
|
|
|
|
|
|
|