mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-27 07:40:59 +00:00
feat(common): allow overriding wire_type of a generated message
This commit is contained in:
parent
90ee5f3d38
commit
8847c58bbf
@ -28,6 +28,7 @@ extend google.protobuf.EnumValueOptions {
|
||||
/** Options for tagging message types */
|
||||
extend google.protobuf.MessageOptions {
|
||||
optional bool unstable = 50001; // indicate that a message definition might change at any time
|
||||
optional uint32 wire_type = 50002; // override wire type specified in the MessageType enum
|
||||
}
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ import tempfile
|
||||
from collections import namedtuple, defaultdict
|
||||
|
||||
import attr
|
||||
import construct as c
|
||||
|
||||
from google.protobuf import descriptor_pb2
|
||||
|
||||
@ -32,6 +33,22 @@ FIELD_TYPES = {
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
ListOfSimpleValues = c.GreedyRange(
|
||||
c.Struct(
|
||||
"key" / c.VarInt,
|
||||
"value" / c.VarInt,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def parse_protobuf_simple(data):
|
||||
"""Micro-parse protobuf-encoded data.
|
||||
|
||||
Assume every field is of type 0 (varint), and parse to a dict of fieldnum: value.
|
||||
"""
|
||||
return {v.key >> 3: v.value for v in ListOfSimpleValues.parse(data)}
|
||||
|
||||
|
||||
PROTOC = shutil.which("protoc")
|
||||
if not PROTOC:
|
||||
print("protoc command not found")
|
||||
@ -177,11 +194,16 @@ class Descriptor:
|
||||
self.messages = []
|
||||
self.enums = []
|
||||
self.enum_types = defaultdict(dict)
|
||||
|
||||
self.extensions = {}
|
||||
|
||||
for file in self.files:
|
||||
self.messages += file.message_type
|
||||
self.enums += file.enum_type
|
||||
for message in file.message_type:
|
||||
self._nested_types_from_message(message)
|
||||
for extension in file.extension:
|
||||
self.extensions[extension.name] = extension.number
|
||||
|
||||
if not self.messages and not self.enums:
|
||||
raise RuntimeError("No messages and no enums found.")
|
||||
@ -191,6 +213,20 @@ class Descriptor:
|
||||
|
||||
self.out_dir = None
|
||||
|
||||
def _get_extension(self, something, extension_name, default=None):
|
||||
# There doesn't seem to be a sane way to access extensions on a descriptor
|
||||
# via the google.protobuf API.
|
||||
# We do have access to descriptors of the extensions...
|
||||
extension_num = self.extensions[extension_name]
|
||||
# ...and the "options" descriptor _does_ include the extension data. But while
|
||||
# the API provides access to unknown fields, it hides the extensions.
|
||||
# What we do is re-encode the options descriptor...
|
||||
options_bytes = something.options.SerializeToString()
|
||||
# ...and re-parse it as a dict of uvarints...
|
||||
simple_values = parse_protobuf_simple(options_bytes)
|
||||
# ...and extract the value corresponding to the extension we care about.
|
||||
return simple_values.get(extension_num, default)
|
||||
|
||||
def _nested_types_from_message(self, message):
|
||||
self.messages += message.nested_type
|
||||
self.enums += message.enum_type
|
||||
@ -286,7 +322,10 @@ class Descriptor:
|
||||
|
||||
def process_message(self, message, include_deprecated=False):
|
||||
logging.debug("Processing message {}".format(message.name))
|
||||
msg_id = self.message_types.get(message.name)
|
||||
|
||||
msg_id = self._get_extension(message, "wire_type")
|
||||
if msg_id is None:
|
||||
msg_id = self.message_types.get(message.name)
|
||||
|
||||
# "from .. import protobuf as p"
|
||||
yield self.protobuf_import + " as p"
|
||||
|
Loading…
Reference in New Issue
Block a user