feat(common): allow overriding wire_type of a generated message

pull/1279/head
matejcik 4 years ago committed by matejcik
parent 90ee5f3d38
commit 8847c58bbf

@ -28,6 +28,7 @@ extend google.protobuf.EnumValueOptions {
/** Options for tagging message types */
extend google.protobuf.MessageOptions {
optional bool unstable = 50001; // indicate that a message definition might change at any time
optional uint32 wire_type = 50002; // override wire type specified in the MessageType enum
}

@ -13,6 +13,7 @@ import tempfile
from collections import namedtuple, defaultdict
import attr
import construct as c
from google.protobuf import descriptor_pb2
@ -32,6 +33,22 @@ FIELD_TYPES = {
}
# fmt: on
ListOfSimpleValues = c.GreedyRange(
c.Struct(
"key" / c.VarInt,
"value" / c.VarInt,
)
)
def parse_protobuf_simple(data):
"""Micro-parse protobuf-encoded data.
Assume every field is of type 0 (varint), and parse to a dict of fieldnum: value.
"""
return {v.key >> 3: v.value for v in ListOfSimpleValues.parse(data)}
PROTOC = shutil.which("protoc")
if not PROTOC:
print("protoc command not found")
@ -177,11 +194,16 @@ class Descriptor:
self.messages = []
self.enums = []
self.enum_types = defaultdict(dict)
self.extensions = {}
for file in self.files:
self.messages += file.message_type
self.enums += file.enum_type
for message in file.message_type:
self._nested_types_from_message(message)
for extension in file.extension:
self.extensions[extension.name] = extension.number
if not self.messages and not self.enums:
raise RuntimeError("No messages and no enums found.")
@ -191,6 +213,20 @@ class Descriptor:
self.out_dir = None
def _get_extension(self, something, extension_name, default=None):
# There doesn't seem to be a sane way to access extensions on a descriptor
# via the google.protobuf API.
# We do have access to descriptors of the extensions...
extension_num = self.extensions[extension_name]
# ...and the "options" descriptor _does_ include the extension data. But while
# the API provides access to unknown fields, it hides the extensions.
# What we do is re-encode the options descriptor...
options_bytes = something.options.SerializeToString()
# ...and re-parse it as a dict of uvarints...
simple_values = parse_protobuf_simple(options_bytes)
# ...and extract the value corresponding to the extension we care about.
return simple_values.get(extension_num, default)
def _nested_types_from_message(self, message):
self.messages += message.nested_type
self.enums += message.enum_type
@ -286,7 +322,10 @@ class Descriptor:
def process_message(self, message, include_deprecated=False):
logging.debug("Processing message {}".format(message.name))
msg_id = self.message_types.get(message.name)
msg_id = self._get_extension(message, "wire_type")
if msg_id is None:
msg_id = self.message_types.get(message.name)
# "from .. import protobuf as p"
yield self.protobuf_import + " as p"

Loading…
Cancel
Save