1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-10-13 11:29:11 +00:00
trezor-firmware/protob/pb2py
matejcik c4420e41d3 protob: add a smart pb2py builder
Now we don't need build_protobuf anymore, and this script is usable
by both core and python-trezor
(as well as for generating custom protobufs from other sources).

We still need protoc, unfortunately, but pb2py now calls it
by itself, a little more robustly: instead of generated
python classes, it uses the FileDescriptorSet output, which
is parsable by a built-in protobuf class.

To support the script, messages.proto and types.proto must
set a common package. Also, there is currently no support for
compiling more than one proto file; we depend on the fact
that messages.proto imports types.proto.
(If this is needed, it should be relatively simple to add:
just pass more than one file to the embedded protoc call.)
2018-07-02 18:37:56 +02:00

337 lines
12 KiB
Python
Executable File

#!/usr/bin/env python3
# Converts protobuf definitions of TREZOR wire messages (.proto files, compiled
# by protoc into a FileDescriptorSet) to plain-python objects as used in
# TREZOR Core and python-trezor
import argparse
import importlib
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from collections import namedtuple
from google.protobuf import descriptor_pb2
# One parsed message field, as built by Descriptor.parse_field():
#   name, number       -- field name and tag number from the .proto definition
#   proto_type         -- type expression written into the generated FIELDS table
#   py_type            -- type name used in the generated __init__ annotations
#   repeated, required -- flags derived from the field label
#   orig               -- the original FieldDescriptorProto
ProtoField = namedtuple(
    "ProtoField", "name, number, proto_type, py_type, repeated, required, orig"
)

# First line of every generated file. The __main__ cleanup compares a file's
# first line against it to decide whether the file was generated by pb2py.
AUTO_HEADER = "# Automatically generated by pb2py\n"

# Scalar protobuf field type -> (serializer type written into FIELDS,
# Python type used in annotations). Message-typed fields are handled
# separately in Descriptor.parse_field; any other type not listed here
# (e.g. int32/int64, fixed-width or floating-point types) makes parse_field
# raise ValueError.
# fmt: off
FIELD_TYPES = {
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: ('p.UnicodeType', 'str'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: ('p.BoolType', 'bool'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: ('p.BytesType', 'bytes'),
}
# fmt: on
def protoc(files, additional_include_dirs=()):
    """Compile .proto files with protoc and return the serialized descriptor set.

    :param files: paths of the .proto files to compile (any iterable).
    :param additional_include_dirs: extra ``-I`` directories searched for
        imports, e.g. the location of protobuf's bundled .proto files. The
        directory of every input file is added automatically; empty entries
        are ignored.
    :returns: the ``FileDescriptorSet`` written by
        ``protoc --descriptor_set_out``, as raw bytes.
    :raises subprocess.CalledProcessError: if protoc exits with an error.
    :raises FileNotFoundError: if the ``protoc`` executable cannot be found.
    """
    files = list(files)

    # Build the include list in a stable order without duplicates. A set
    # iterates in a hash-dependent order, so the -I flags (and with them the
    # order in which protoc searches directories) could differ between runs.
    # Empty entries are skipped: a bare "-I" would make protoc take the
    # following command-line argument as the include path.
    candidates = list(additional_include_dirs)
    candidates += [os.path.dirname(file) or "." for file in files]
    include_dirs = []
    for include_dir in candidates:
        if include_dir and include_dir not in include_dirs:
            include_dirs.append(include_dir)
    protoc_includes = ["-I" + include_dir for include_dir in include_dirs]

    # Note that we could avoid creating temp files if protoc let us write to stdout
    # directly. this is currently only possible on Unix, by passing /dev/stdout as
    # the file name. Since there's no direct Windows equivalent, not counting
    # being creative with named pipes, special-casing this is not worth the effort.
    with tempfile.TemporaryDirectory() as tmpdir:
        outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
        subprocess.check_call(
            ["protoc", "--descriptor_set_out={}".format(outfile)]
            + protoc_includes
            + files
        )
        with open(outfile, "rb") as f:
            return f.read()
def strip_leader(s, prefix):
    """Return ``s`` with a leading ``prefix`` + underscore removed, if present.

    For example ``strip_leader("MessageType_Initialize", "MessageType")``
    returns ``"Initialize"``. Strings that do not start with the underscored
    prefix come back unchanged.
    """
    underscored = prefix + "_"
    return s[len(underscored):] if s.startswith(underscored) else s
def import_statement_from_path(path):
    """Build an import statement that binds the last component of ``path``.

    Leading dots are kept as a relative-import prefix; everything up to the
    last remaining dot becomes the package imported from::

        "protobuf"          -> "import protobuf"
        ".protobuf"         -> "from . import protobuf"
        "trezor.protobuf"   -> "from trezor import protobuf"
        "..trezor.protobuf" -> "from ..trezor import protobuf"

    :raises ValueError: if ``path`` has no name to import (empty, only dots,
        or ending in a dot).
    """
    stripped = path.lstrip(".")
    dot_prefix = path[: len(path) - len(stripped)]
    # rpartition yields strings (unlike slicing a split() list), so the
    # package part can be concatenated with the dot prefix directly.
    package, _, import_name = stripped.rpartition(".")
    if not import_name:
        raise ValueError("Invalid import path: {!r}".format(path))

    from_part = dot_prefix + package
    if from_part:
        return "from {} import {}".format(from_part, import_name)
    return "import {}".format(import_name)
class Descriptor:
    """Turn a compiled protobuf ``FileDescriptorSet`` into Python modules.

    Every message becomes its own module containing one ``p.MessageType``
    subclass, and every enum its own module of integer constants, where ``p``
    is the runtime protobuf module named by ``import_path``.

    :param data: serialized ``FileDescriptorSet`` bytes, e.g. from ``protoc()``.
    :param message_type: name of the enum whose values give the message wire
        IDs, spelled ``<message_type>_<MessageName> = <id>``.
    :param import_path: dotted path of the runtime protobuf module; leading
        dots produce a relative import.
    :raises RuntimeError: if the descriptor set has no messages and no enums.
    """

    def __init__(self, data, message_type="MessageType", import_path="protobuf"):
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)
        self.files = self.descriptor.file
        logging.debug("found {} files".format(len(self.files)))

        # Collect the file-level messages and enums of all files into one
        # flat namespace.
        self.messages = []
        self.enums = []
        for file in self.files:
            self.messages += file.message_type
            self.enums += file.enum_type

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        self.message_types = self.find_message_types(message_type)
        self.protobuf_import = import_statement_from_path(import_path)

        # Set by write_classes() before any file is written.
        self.out_dir = None

    def find_message_types(self, message_type):
        """Map message names to wire IDs using the enum named ``message_type``.

        Each value ``<message_type>_<Name> = <id>`` contributes ``{"Name": id}``.
        If no such enum exists, a warning is logged and ``{}`` is returned.
        """
        message_type_enum = next(
            (enum for enum in self.enums if enum.name == message_type), None
        )
        if message_type_enum is None:
            # Not fatal: messages are then generated without MESSAGE_WIRE_TYPE.
            # Log the parameter, not the script's global ``args``, which does
            # not exist when this class is used as a library.
            logging.warning("Message IDs not found under '{}'".format(message_type))
            return {}

        return {
            strip_leader(value.name, message_type): value.number
            for value in message_type_enum.value
        }

    def parse_field(self, field):
        """Convert a ``FieldDescriptorProto`` into a ``ProtoField``.

        Message-typed fields use the bare message name (package path dropped)
        as both protobuf type and Python type; scalar types are looked up in
        ``FIELD_TYPES``. Repeated fields get a ``List[...]`` Python type.

        :raises ValueError: for a field type that is not supported.
        """
        repeated = field.label == field.LABEL_REPEATED
        required = field.label == field.LABEL_REQUIRED
        if field.type == field.TYPE_MESSAGE:
            # ".package.Name" -> "Name"
            type_name = field.type_name.rsplit(".", 1)[-1]
            proto_type = py_type = type_name
        else:
            try:
                proto_type, py_type = FIELD_TYPES[field.type]
            except KeyError:
                raise ValueError(
                    "Unknown field type {} for field {}".format(field.type, field.name)
                ) from None

        if repeated:
            py_type = "List[{}]".format(py_type)

        return ProtoField(
            name=field.name,
            number=field.number,
            proto_type=proto_type,
            py_type=py_type,
            repeated=repeated,
            required=required,
            orig=field,
        )

    def create_message_import(self, name):
        """Return the relative import line for the generated module ``name``."""
        return "from .{0} import {0}".format(name)

    def process_message_imports(self, fields):
        """Yield one import line per distinct message type used by ``fields``, sorted."""
        imports = {
            field.proto_type
            for field in fields
            if field.orig.type == field.orig.TYPE_MESSAGE
        }
        for name in sorted(imports):
            yield self.create_message_import(name)

    def create_init_method(self, fields):
        """Yield the source lines of the generated ``__init__`` method.

        Every field becomes an optional parameter defaulting to None. An
        omitted repeated field is stored as a fresh empty list, so instances
        never share one mutable default.
        """
        # please keep the yields aligned
        # fmt: off
        ...  # https://github.com/ambv/black/issues/385
        yield "    def __init__("
        yield "        self,"
        for field in fields:
            yield "        {}: {} = None,".format(field.name, field.py_type)
        yield "    ) -> None:"
        for field in fields:
            if field.repeated:
                yield "        self.{0} = {0} if {0} is not None else []".format(field.name)
            else:
                yield "        self.{0} = {0}".format(field.name)
        # fmt: on

    def _format_field_entry(self, field):
        """Return the line of the generated ``FIELDS`` dict for one ``ProtoField``."""
        comments = []
        if field.required:
            comments.append("required")
        if field.orig.HasField("default_value"):
            comments.append("default={}".format(field.orig.default_value))
        comment = "  # " + " ".join(comments) if comments else ""
        flags = "p.FLAG_REPEATED" if field.repeated else "0"
        return "        {num}: ('{name}', {type}, {flags}),{comment}".format(
            num=field.number,
            name=field.name,
            type=field.proto_type,
            flags=flags,
            comment=comment,
        )

    def process_message(self, message):
        """Yield the source lines of the module generated for ``message``.

        The generated lines are indented with 4 spaces per level so that the
        emitted module is valid Python.
        """
        logging.debug("Processing message {}".format(message.name))
        msg_id = self.message_types.get(message.name)

        # e.g. "from .. import protobuf as p"
        yield self.protobuf_import + " as p"

        fields = [self.parse_field(field) for field in message.field]

        if any(field.repeated for field in fields):
            # Repeated fields are annotated as List[...]; the import is
            # guarded so that a missing ``typing`` module does not fail it.
            yield "if __debug__:"
            yield "    try:"
            yield "        from typing import List"
            yield "    except ImportError:"
            yield "        List = None  # type: ignore"

        yield from self.process_message_imports(fields)

        yield ""
        yield ""
        yield "class {}(p.MessageType):".format(message.name)

        if msg_id is not None:
            yield "    MESSAGE_WIRE_TYPE = {}".format(msg_id)

        if fields:
            yield "    FIELDS = {"
            for field in fields:
                yield self._format_field_entry(field)
            yield "    }"
            yield ""
            yield from self.create_init_method(fields)

        # The class body needs at least one statement. Compare with None:
        # a wire ID of 0 is valid and already produced MESSAGE_WIRE_TYPE.
        if not fields and msg_id is None:
            yield "    pass"

    def process_enum(self, enum):
        """Yield ``NAME = value`` lines for the module generated for ``enum``.

        The enum name is stripped from each constant, e.g.
        "PinMatrixRequestType_Current" -> "Current". If the enum name ends in
        "Type", the shorter prefix without it is stripped as well, e.g.
        "ButtonRequest_Other" in enum ButtonRequestType -> "Other".
        """
        logging.debug("Processing enum {}".format(enum.name))
        # Loop-invariant: computed once per enum instead of once per value.
        short_prefix = None
        if enum.name.endswith("Type"):
            short_prefix, _ = enum.name.rsplit("Type", 1)

        for value in enum.value:
            name = strip_leader(value.name, enum.name)
            if short_prefix is not None:
                name = strip_leader(name, short_prefix)
            yield "{} = {}".format(name, value.number)

    def process_messages(self, messages):
        """Write one module per message, in name order."""
        for message in sorted(messages, key=lambda m: m.name):
            self.write_to_file(message.name, self.process_message(message))

    def process_enums(self, enums):
        """Write one module per enum, in name order."""
        for enum in sorted(enums, key=lambda e: e.name):
            self.write_to_file(enum.name, self.process_enum(enum))

    def write_to_file(self, name, out):
        """Write the lines yielded by ``out`` to ``<out_dir>/<name>.py``.

        The file starts with AUTO_HEADER followed by a ``# fmt: off`` line.
        UTF-8 is used explicitly because generated comments can contain
        non-ASCII default values.
        """
        logging.debug("Writing file {}.py".format(name))
        path = os.path.join(self.out_dir, name + ".py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(AUTO_HEADER)
            f.write("# fmt: off\n")
            for line in out:
                f.write(line + "\n")

    def write_init_py(self):
        """Write ``<out_dir>/__init__.py`` importing every message class and enum module."""
        filename = os.path.join(self.out_dir, "__init__.py")
        with open(filename, "w", encoding="utf-8") as init_py:
            init_py.write(AUTO_HEADER)
            init_py.write("# fmt: off\n\n")
            for message in sorted(self.messages, key=lambda m: m.name):
                init_py.write(self.create_message_import(message.name) + "\n")
            for enum in sorted(self.enums, key=lambda m: m.name):
                init_py.write("from . import {}\n".format(enum.name))

    def write_classes(self, out_dir, init_py=True):
        """Generate all message and enum modules into ``out_dir``.

        :param out_dir: existing directory to write the ``.py`` files into.
        :param init_py: also write an ``__init__.py`` importing every module.
        """
        self.out_dir = out_dir
        self.process_messages(self.messages)
        self.process_enums(self.enums)
        if init_py:
            self.write_init_py()
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("proto", nargs="+", help="Protobuf definition files")
    # Required: the generated files are copied here, and there is no sensible default.
    parser.add_argument("-o", "--out-dir", required=True, help="Directory for generated source code")
    parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
    parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
    parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
    parser.add_argument("--protobuf-default-include", default="/usr/include", help="Location of protobuf's default .proto files.")
    # fmt: on
    # NOTE: ``args`` is deliberately kept at module level.
    args = parser.parse_args()

    descriptor_proto = protoc(args.proto, (args.protobuf_default_include,))
    descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Generate everything into a scratch directory first, so that a failure
        # during generation happens before out_dir is modified.
        descriptor.write_classes(tmpdir, not args.no_init_py)

        os.makedirs(args.out_dir, exist_ok=True)

        # Remove files left by a previous run (recognized by AUTO_HEADER on the
        # first line), so that messages dropped from the .proto files do not
        # linger. Hand-written files and subdirectories are left alone.
        for filename in os.listdir(args.out_dir):
            pathname = os.path.join(args.out_dir, filename)
            if not os.path.isfile(pathname):
                continue
            try:
                with open(pathname, "r", encoding="utf-8") as f:
                    is_generated = next(f, None) == AUTO_HEADER
            except (OSError, UnicodeDecodeError):
                # Unreadable or not text: not a file we generated.
                continue
            # Unlink only after the file is closed; removing an open file
            # fails on Windows.
            if is_generated:
                try:
                    os.unlink(pathname)
                except OSError as e:
                    logging.warning("Could not remove stale file {}: {}".format(pathname, e))

        for filename in os.listdir(tmpdir):
            shutil.copy(os.path.join(tmpdir, filename), args.out_dir)