From 8847c58bbfe63afada2de2571877551aab9e9273 Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 14 Sep 2020 12:48:03 +0200 Subject: [PATCH] feat(common): allow overriding wire_type of a generated message --- common/protob/messages.proto | 1 + common/protob/pb2py | 41 +++++++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/common/protob/messages.proto b/common/protob/messages.proto index 9808717be9..a4f922b50f 100644 --- a/common/protob/messages.proto +++ b/common/protob/messages.proto @@ -28,6 +28,7 @@ extend google.protobuf.EnumValueOptions { /** Options for tagging message types */ extend google.protobuf.MessageOptions { optional bool unstable = 50001; // indicate that a message definition might change at any time + optional uint32 wire_type = 50002; // override wire type specified in the MessageType enum } diff --git a/common/protob/pb2py b/common/protob/pb2py index a21b90bbf2..09feef39c8 100755 --- a/common/protob/pb2py +++ b/common/protob/pb2py @@ -13,6 +13,7 @@ import tempfile from collections import namedtuple, defaultdict import attr +import construct as c from google.protobuf import descriptor_pb2 @@ -32,6 +33,22 @@ FIELD_TYPES = { } # fmt: on +ListOfSimpleValues = c.GreedyRange( + c.Struct( + "key" / c.VarInt, + "value" / c.VarInt, + ) +) + + +def parse_protobuf_simple(data): + """Micro-parse protobuf-encoded data. + + Assume every field is of type 0 (varint), and parse to a dict of fieldnum: value. + """ + return {v.key >> 3: v.value for v in ListOfSimpleValues.parse(data)} + + PROTOC = shutil.which("protoc") if not PROTOC: print("protoc command not found") @@ -177,11 +194,16 @@ class Descriptor: self.messages = [] self.enums = [] self.enum_types = defaultdict(dict) + + self.extensions = {} + for file in self.files: self.messages += file.message_type self.enums += file.enum_type for message in file.message_type: self._nested_types_from_message(message) + for extension in file.extension: + self.extensions[extension.name] = extension.number if not self.messages and not self.enums: raise RuntimeError("No messages and no enums found.") @@ -191,6 +213,20 @@ class Descriptor: self.out_dir = None + def _get_extension(self, something, extension_name, default=None): + # There doesn't seem to be a sane way to access extensions on a descriptor + # via the google.protobuf API. + # We do have access to descriptors of the extensions... + extension_num = self.extensions[extension_name] + # ...and the "options" descriptor _does_ include the extension data. But while + # the API provides access to unknown fields, it hides the extensions. + # What we do is re-encode the options descriptor... + options_bytes = something.options.SerializeToString() + # ...and re-parse it as a dict of uvarints... + simple_values = parse_protobuf_simple(options_bytes) + # ...and extract the value corresponding to the extension we care about. + return simple_values.get(extension_num, default) + def _nested_types_from_message(self, message): self.messages += message.nested_type self.enums += message.enum_type @@ -286,7 +322,10 @@ class Descriptor: def process_message(self, message, include_deprecated=False): logging.debug("Processing message {}".format(message.name)) - msg_id = self.message_types.get(message.name) + + msg_id = self._get_extension(message, "wire_type") + if msg_id is None: + msg_id = self.message_types.get(message.name) # "from .. import protobuf as p" yield self.protobuf_import + " as p"