mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
protob: add a smart pb2py builder
Now we don't need build_protobuf anymore and this is usable by both core and python-trezor (as well as generating custom protobufs from other sources) We still need protoc, unfortunately, but pb2py now calls it by itself. (little more robustly; instead of generated python classes, it uses the FileDescriptorSet output which is parsable by a built-in protobuf class) To support the script, messages.proto and types.proto must set a common package. Also there is currently no support for compiling more than one proto file, we depend on the fact that messages.proto import types.proto. (if this is needed, it should be relatively simple to add, simply pass more than one file to the embedded protoc call)
This commit is contained in:
parent
dc024ad2d2
commit
c4420e41d3
336
protob/pb2py
Executable file
336
protob/pb2py
Executable file
@ -0,0 +1,336 @@
|
||||
#!/usr/bin/env python3
|
||||
# Converts Google's protobuf python definitions of TREZOR wire messages
|
||||
# to plain-python objects as used in TREZOR Core and python-trezor
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
|
||||
from google.protobuf import descriptor_pb2
|
||||
|
||||
ProtoField = namedtuple(
|
||||
"ProtoField", "name, number, proto_type, py_type, repeated, required, orig"
|
||||
)
|
||||
|
||||
AUTO_HEADER = "# Automatically generated by pb2py\n"
|
||||
|
||||
# fmt: off
|
||||
FIELD_TYPES = {
|
||||
descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: ('p.UVarintType', 'int'),
|
||||
descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: ('p.UVarintType', 'int'),
|
||||
descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: ('p.UVarintType', 'int'),
|
||||
descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: ('p.SVarintType', 'int'),
|
||||
descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: ('p.SVarintType', 'int'),
|
||||
descriptor_pb2.FieldDescriptorProto.TYPE_STRING: ('p.UnicodeType', 'str'),
|
||||
descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: ('p.BoolType', 'bool'),
|
||||
descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: ('p.BytesType', 'bytes'),
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
def protoc(files, additional_include_dirs=()):
|
||||
"""Compile code with protoc and return the data."""
|
||||
include_dirs = set(additional_include_dirs)
|
||||
for file in files:
|
||||
dirname = os.path.dirname(file) or "."
|
||||
include_dirs.add(dirname)
|
||||
protoc_includes = ["-I" + dir for dir in include_dirs]
|
||||
|
||||
# Note that we could avoid creating temp files if protoc let us write to stdout
|
||||
# directly. this is currently only possible on Unix, by passing /dev/stdout as
|
||||
# the file name. Since there's no direct Windows equivalent, not counting
|
||||
# being creative with named pipes, special-casing this is not worth the effort.
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
|
||||
subprocess.check_call(
|
||||
["protoc", "--descriptor_set_out={}".format(outfile)]
|
||||
+ protoc_includes
|
||||
+ files
|
||||
)
|
||||
with open(outfile, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def strip_leader(s, prefix):
|
||||
"""Remove given prefix from underscored name."""
|
||||
leader = prefix + "_"
|
||||
if s.startswith(leader):
|
||||
return s[len(leader) :]
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def import_statement_from_path(path):
|
||||
# separate leading dots
|
||||
dot_prefix = ""
|
||||
while path.startswith("."):
|
||||
dot_prefix += "."
|
||||
path = path[1:]
|
||||
|
||||
# split on remaining dots
|
||||
split_path = path.rsplit(".", maxsplit=1)
|
||||
leader, import_name = split_path[:-1], split_path[-1]
|
||||
|
||||
if leader:
|
||||
from_part = dot_prefix + leader
|
||||
elif dot_prefix:
|
||||
from_part = dot_prefix
|
||||
else:
|
||||
from_part = ""
|
||||
|
||||
if from_part:
|
||||
return "from {} import {}".format(from_part, import_name)
|
||||
else:
|
||||
return "import {}".format(import_name)
|
||||
|
||||
|
||||
class Descriptor:
|
||||
def __init__(self, data, message_type="MessageType", import_path="protobuf"):
|
||||
self.descriptor = descriptor_pb2.FileDescriptorSet()
|
||||
self.descriptor.ParseFromString(data)
|
||||
|
||||
self.files = self.descriptor.file
|
||||
|
||||
logging.debug("found {} files".format(len(self.files)))
|
||||
|
||||
# find messages and enums
|
||||
self.messages = []
|
||||
self.enums = []
|
||||
for file in self.files:
|
||||
self.messages += file.message_type
|
||||
self.enums += file.enum_type
|
||||
|
||||
if not self.messages and not self.enums:
|
||||
raise RuntimeError("No messages and no enums found.")
|
||||
|
||||
self.message_types = self.find_message_types(message_type)
|
||||
self.protobuf_import = import_statement_from_path(import_path)
|
||||
|
||||
self.out_dir = None
|
||||
|
||||
def find_message_types(self, message_type):
|
||||
message_types = {}
|
||||
try:
|
||||
message_type_enum = next(
|
||||
enum for enum in self.enums if enum.name == message_type
|
||||
)
|
||||
for value in message_type_enum.value:
|
||||
name = strip_leader(value.name, message_type)
|
||||
message_types[name] = value.number
|
||||
|
||||
except StopIteration:
|
||||
# No message type found. Oh well.
|
||||
logging.warning(
|
||||
"Message IDs not found under '{}'".format(args.message_type)
|
||||
)
|
||||
|
||||
return message_types
|
||||
|
||||
def parse_field(self, field):
|
||||
repeated = field.label == field.LABEL_REPEATED
|
||||
required = field.label == field.LABEL_REQUIRED
|
||||
if field.type == field.TYPE_MESSAGE:
|
||||
# ignore package path
|
||||
type_name = field.type_name.rsplit(".")[-1]
|
||||
proto_type = py_type = type_name
|
||||
else:
|
||||
try:
|
||||
proto_type, py_type = FIELD_TYPES[field.type]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Unknown field type {} for field {}".format(field.type, field.name)
|
||||
) from None
|
||||
|
||||
if repeated:
|
||||
py_type = "List[{}]".format(py_type)
|
||||
|
||||
return ProtoField(
|
||||
name=field.name,
|
||||
number=field.number,
|
||||
proto_type=proto_type,
|
||||
py_type=py_type,
|
||||
repeated=repeated,
|
||||
required=required,
|
||||
orig=field,
|
||||
)
|
||||
|
||||
def create_message_import(self, name):
|
||||
return "from .{0} import {0}".format(name)
|
||||
|
||||
def process_message_imports(self, fields):
|
||||
imports = set(
|
||||
field.proto_type
|
||||
for field in fields
|
||||
if field.orig.type == field.orig.TYPE_MESSAGE
|
||||
)
|
||||
|
||||
for name in sorted(imports):
|
||||
yield self.create_message_import(name)
|
||||
|
||||
def create_init_method(self, fields):
|
||||
# please keep the yields aligned
|
||||
# fmt: off
|
||||
... # https://github.com/ambv/black/issues/385
|
||||
yield " def __init__("
|
||||
yield " self,"
|
||||
for field in fields:
|
||||
yield " {}: {} = None,".format(field.name, field.py_type)
|
||||
yield " ) -> None:"
|
||||
|
||||
for field in fields:
|
||||
if field.repeated:
|
||||
yield " self.{0} = {0} if {0} is not None else []".format(field.name)
|
||||
else:
|
||||
yield " self.{0} = {0}".format(field.name)
|
||||
# fmt: on
|
||||
|
||||
def process_message(self, message):
|
||||
logging.debug("Processing message {}".format(message.name))
|
||||
msg_id = self.message_types.get(message.name)
|
||||
|
||||
# "from .. import protobuf as p"
|
||||
yield self.protobuf_import + " as p"
|
||||
|
||||
fields = [self.parse_field(field) for field in message.field]
|
||||
|
||||
if any(field.repeated for field in fields):
|
||||
yield "if __debug__:"
|
||||
yield " try:"
|
||||
yield " from typing import List"
|
||||
yield " except ImportError:"
|
||||
yield " List = None # type: ignore"
|
||||
|
||||
yield from self.process_message_imports(fields)
|
||||
|
||||
yield ""
|
||||
yield ""
|
||||
yield "class {}(p.MessageType):".format(message.name)
|
||||
|
||||
if msg_id is not None:
|
||||
yield " MESSAGE_WIRE_TYPE = {}".format(msg_id)
|
||||
|
||||
if fields:
|
||||
yield " FIELDS = {"
|
||||
for field in fields:
|
||||
comments = []
|
||||
if field.required:
|
||||
comments.append("required")
|
||||
if field.orig.HasField("default_value"):
|
||||
comments.append("default={}".format(field.orig.default_value))
|
||||
|
||||
if comments:
|
||||
comment = " # " + " ".join(comments)
|
||||
else:
|
||||
comment = ""
|
||||
|
||||
if field.repeated:
|
||||
flags = "p.FLAG_REPEATED"
|
||||
else:
|
||||
flags = "0"
|
||||
|
||||
yield " {num}: ('{name}', {type}, {flags}),{comment}".format(
|
||||
num=field.number,
|
||||
name=field.name,
|
||||
type=field.proto_type,
|
||||
flags=flags,
|
||||
comment=comment,
|
||||
)
|
||||
|
||||
yield " }"
|
||||
yield ""
|
||||
yield from self.create_init_method(fields)
|
||||
|
||||
if not fields and not msg_id:
|
||||
yield " pass"
|
||||
|
||||
def process_enum(self, enum):
|
||||
logging.debug("Processing enum {}".format(enum.name))
|
||||
|
||||
for value in enum.value:
|
||||
# Remove type name from the beginning of the constant
|
||||
# For example "PinMatrixRequestType_Current" -> "Current"
|
||||
enum_prefix = enum.name
|
||||
name = value.name
|
||||
name = strip_leader(name, enum_prefix)
|
||||
|
||||
# If type ends with *Type, but constant use type name without *Type, remove it too :)
|
||||
# For example "ButtonRequestType & ButtonRequest_Other" => "Other"
|
||||
if enum_prefix.endswith("Type"):
|
||||
enum_prefix, _ = enum_prefix.rsplit("Type", 1)
|
||||
name = strip_leader(name, enum_prefix)
|
||||
|
||||
yield "{} = {}".format(name, value.number)
|
||||
|
||||
def process_messages(self, messages):
|
||||
for message in sorted(messages, key=lambda m: m.name):
|
||||
self.write_to_file(message.name, self.process_message(message))
|
||||
|
||||
def process_enums(self, enums):
|
||||
for enum in sorted(enums, key=lambda e: e.name):
|
||||
self.write_to_file(enum.name, self.process_enum(enum))
|
||||
|
||||
def write_to_file(self, name, out):
|
||||
# Write generated sourcecode to given file
|
||||
logging.debug("Writing file {}.py".format(name))
|
||||
with open(os.path.join(self.out_dir, name + ".py"), "w") as f:
|
||||
f.write(AUTO_HEADER)
|
||||
f.write("# fmt: off\n")
|
||||
for line in out:
|
||||
f.write(line + "\n")
|
||||
|
||||
def write_init_py(self):
|
||||
filename = os.path.join(self.out_dir, "__init__.py")
|
||||
with open(filename, "w") as init_py:
|
||||
init_py.write(AUTO_HEADER)
|
||||
init_py.write("# fmt: off\n\n")
|
||||
for message in sorted(self.messages, key=lambda m: m.name):
|
||||
init_py.write(self.create_message_import(message.name) + "\n")
|
||||
for enum in sorted(self.enums, key=lambda m: m.name):
|
||||
init_py.write("from . import {}\n".format(enum.name))
|
||||
|
||||
def write_classes(self, out_dir, init_py=True):
|
||||
self.out_dir = out_dir
|
||||
self.process_messages(self.messages)
|
||||
self.process_enums(self.enums)
|
||||
if init_py:
|
||||
self.write_init_py()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# fmt: off
|
||||
parser.add_argument("proto", nargs="+", help="Protobuf definition files")
|
||||
parser.add_argument("-o", "--out-dir", help="Directory for generated source code")
|
||||
parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
|
||||
parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
|
||||
parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
|
||||
parser.add_argument("--protobuf-default-include", default="/usr/include", help="Location of protobuf's default .proto files.")
|
||||
# fmt: on
|
||||
args = parser.parse_args()
|
||||
|
||||
descriptor_proto = protoc(args.proto, (args.protobuf_default_include,))
|
||||
descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
descriptor.write_classes(tmpdir, not args.no_init_py)
|
||||
|
||||
for filename in os.listdir(args.out_dir):
|
||||
pathname = os.path.join(args.out_dir, filename)
|
||||
try:
|
||||
with open(pathname, "r") as f:
|
||||
if next(f, None) == AUTO_HEADER:
|
||||
os.unlink(pathname)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for filename in os.listdir(tmpdir):
|
||||
src = os.path.join(tmpdir, filename)
|
||||
shutil.copy(src, args.out_dir)
|
Loading…
Reference in New Issue
Block a user