#!/usr/bin/env python3
# Converts Google's protobuf python definitions of TREZOR wire messages
# to plain-python objects as used in TREZOR Core and python-trezor

import argparse
import importlib
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from collections import namedtuple

from google.protobuf import descriptor_pb2

# Parsed view of one protobuf field; `orig` keeps the raw FieldDescriptorProto.
ProtoField = namedtuple(
    "ProtoField", "name, number, proto_type, py_type, repeated, required, orig"
)

# First line of every generated file; also used to recognise stale generated
# files in the output directory before they are replaced.
AUTO_HEADER = "# Automatically generated by pb2py\n"

# Maps a protobuf scalar type to (wire type expression, python annotation).
# fmt: off
FIELD_TYPES = {
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM:   ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: ('p.UnicodeType', 'str'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL:   ('p.BoolType', 'bool'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES:  ('p.BytesType', 'bytes'),
}
# fmt: on


def protoc(files, additional_include_dirs=()):
    """Compile code with protoc and return the data.

    Runs the `protoc` executable on `files` and returns the serialized
    FileDescriptorSet as bytes. The directory of every input file is added to
    the include path after `additional_include_dirs`.

    Raises subprocess.CalledProcessError if protoc exits with a non-zero
    status, and OSError if the executable cannot be started.
    """
    # De-duplicate while keeping a stable order: protoc searches -I directories
    # in the order given, so the command line must not depend on set ordering.
    include_dirs = []
    candidates = list(additional_include_dirs)
    candidates.extend(os.path.dirname(proto_file) or "." for proto_file in files)
    for include_dir in candidates:
        if include_dir not in include_dirs:
            include_dirs.append(include_dir)
    protoc_includes = ["-I" + include_dir for include_dir in include_dirs]

    # Note that we could avoid creating temp files if protoc let us write to stdout
    # directly. This is currently only possible on Unix, by passing /dev/stdout as
    # the file name. Since there's no direct Windows equivalent, not counting
    # being creative with named pipes, special-casing this is not worth the effort.
    with tempfile.TemporaryDirectory() as tmpdir:
        outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
        subprocess.check_call(
            ["protoc", "--descriptor_set_out={}".format(outfile)]
            + protoc_includes
            + list(files)
        )
        with open(outfile, "rb") as f:
            return f.read()


def strip_leader(s, prefix):
    """Remove given prefix from underscored name.

    "Foo_Bar" with prefix "Foo" becomes "Bar"; names that do not start with
    "<prefix>_" are returned unchanged.
    """
    leader = prefix + "_"
    if s.startswith(leader):
        return s[len(leader):]
    return s


def import_statement_from_path(path):
    """Build the import statement that binds the last component of `path`.

    "protobuf"           -> "import protobuf"
    "trezorlib.protobuf" -> "from trezorlib import protobuf"
    "..protobuf"         -> "from .. import protobuf"
    "..lib.protobuf"     -> "from ..lib import protobuf"
    """
    # separate leading dots (relative import level)
    stripped = path.lstrip(".")
    dot_prefix = path[: len(path) - len(stripped)]

    # split on the last remaining dot; `leader` is "" when there is none
    leader, _, import_name = stripped.rpartition(".")
    from_part = dot_prefix + leader

    if from_part:
        return "from {} import {}".format(from_part, import_name)
    return "import {}".format(import_name)


class Descriptor:
    """Turns a serialized FileDescriptorSet into plain-python source files.

    One module is generated per top-level message and per top-level enum.
    """

    def __init__(self, data, message_type="MessageType", import_path="protobuf"):
        """Parse `data` and collect top-level messages and enums.

        data: serialized FileDescriptorSet, e.g. as returned by protoc().
        message_type: name of the enum that maps message names to wire IDs.
        import_path: module path of the runtime protobuf module that the
            generated code imports as `p`.

        Raises RuntimeError if the descriptor set has no messages and no enums.
        """
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)

        self.files = self.descriptor.file

        logging.debug("found %d files", len(self.files))

        # find messages and enums
        self.messages = []
        self.enums = []
        for proto_file in self.files:
            self.messages += proto_file.message_type
            self.enums += proto_file.enum_type

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        self.message_types = self.find_message_types(message_type)
        self.protobuf_import = import_statement_from_path(import_path)

        # set by write_classes()
        self.out_dir = None

    def find_message_types(self, message_type):
        """Return {message name: wire ID} taken from the enum `message_type`.

        Value names have the "<message_type>_" prefix stripped. If no enum of
        that name exists, a warning is logged and an empty dict is returned.
        """
        message_type_enum = next(
            (enum for enum in self.enums if enum.name == message_type), None
        )
        if message_type_enum is None:
            # No message type found. Oh well.
            logging.warning("Message IDs not found under '%s'", message_type)
            return {}

        return {
            strip_leader(value.name, message_type): value.number
            for value in message_type_enum.value
        }

    def parse_field(self, field):
        """Convert a FieldDescriptorProto into a ProtoField.

        Raises ValueError for scalar types that are missing from FIELD_TYPES.
        """
        repeated = field.label == field.LABEL_REPEATED
        required = field.label == field.LABEL_REQUIRED
        if field.type == field.TYPE_MESSAGE:
            # ignore package path
            type_name = field.type_name.rsplit(".", 1)[-1]
            proto_type = py_type = type_name
        else:
            try:
                proto_type, py_type = FIELD_TYPES[field.type]
            except KeyError:
                raise ValueError(
                    "Unknown field type {} for field {}".format(field.type, field.name)
                ) from None

        if repeated:
            py_type = "List[{}]".format(py_type)

        return ProtoField(
            name=field.name,
            number=field.number,
            proto_type=proto_type,
            py_type=py_type,
            repeated=repeated,
            required=required,
            orig=field,
        )

    def create_message_import(self, name):
        """Return the relative import line for the generated message `name`."""
        return "from .{0} import {0}".format(name)

    def process_message_imports(self, fields):
        """Yield one import line per distinct message type used by `fields`."""
        imports = {
            field.proto_type
            for field in fields
            if field.orig.type == field.orig.TYPE_MESSAGE
        }

        for name in sorted(imports):
            yield self.create_message_import(name)

    def create_init_method(self, fields):
        """Yield the source lines of the generated class's __init__ method.

        Expects a non-empty `fields`; with no fields the emitted method would
        have no body.
        """
        # please keep the yields aligned
        # fmt: off
        ...  # https://github.com/ambv/black/issues/385
        yield     "    def __init__("
        yield     "        self,"
        for field in fields:
            yield "        {}: {} = None,".format(field.name, field.py_type)
        yield     "    ) -> None:"

        for field in fields:
            if field.repeated:
                # repeated fields default to a fresh list instead of None
                yield "        self.{0} = {0} if {0} is not None else []".format(field.name)
            else:
                yield "        self.{0} = {0}".format(field.name)
        # fmt: on

    def process_message(self, message):
        """Yield the source lines of the module generated for `message`."""
        logging.debug("Processing message %s", message.name)
        msg_id = self.message_types.get(message.name)

        # "from .. import protobuf as p"
        yield self.protobuf_import + " as p"

        fields = [self.parse_field(field) for field in message.field]

        if any(field.repeated for field in fields):
            # `List` is only needed for annotations; tolerate a missing typing module
            yield "if __debug__:"
            yield "    try:"
            yield "        from typing import List"
            yield "    except ImportError:"
            yield "        List = None  # type: ignore"

        yield from self.process_message_imports(fields)

        yield ""
        yield ""
        yield "class {}(p.MessageType):".format(message.name)

        if msg_id is not None:
            yield "    MESSAGE_WIRE_TYPE = {}".format(msg_id)

        if fields:
            yield "    FIELDS = {"
            for field in fields:
                comments = []
                if field.required:
                    comments.append("required")
                if field.orig.HasField("default_value"):
                    comments.append("default={}".format(field.orig.default_value))

                comment = "  # " + " ".join(comments) if comments else ""
                flags = "p.FLAG_REPEATED" if field.repeated else "0"

                yield "        {num}: ('{name}', {type}, {flags}),{comment}".format(
                    num=field.number,
                    name=field.name,
                    type=field.proto_type,
                    flags=flags,
                    comment=comment,
                )

            yield "    }"
            yield ""
            yield from self.create_init_method(fields)

        # The class body is empty only when neither the wire type nor any field
        # was emitted. Compare against None: a wire ID of 0 is a valid ID.
        if not fields and msg_id is None:
            yield "    pass"

    def process_enum(self, enum):
        """Yield `NAME = number` lines for every value of `enum`."""
        logging.debug("Processing enum %s", enum.name)

        # Remove type name from the beginning of the constant
        # For example "PinMatrixRequestType_Current" -> "Current"
        prefixes = [enum.name]
        # If type ends with *Type, but constant use type name without *Type, remove it too :)
        # For example "ButtonRequestType & ButtonRequest_Other" => "Other"
        if enum.name.endswith("Type"):
            prefixes.append(enum.name[: -len("Type")])

        for value in enum.value:
            name = value.name
            for prefix in prefixes:
                name = strip_leader(name, prefix)

            yield "{} = {}".format(name, value.number)

    def process_messages(self, messages):
        """Write one module per message, in name order."""
        for message in sorted(messages, key=lambda m: m.name):
            self.write_to_file(message.name, self.process_message(message))

    def process_enums(self, enums):
        """Write one module per enum, in name order."""
        for enum in sorted(enums, key=lambda e: e.name):
            self.write_to_file(enum.name, self.process_enum(enum))

    def write_to_file(self, name, out):
        """Write the generated lines `out` to `<out_dir>/<name>.py`."""
        # Write generated sourcecode to given file
        logging.debug("Writing file %s.py", name)
        with open(os.path.join(self.out_dir, name + ".py"), "w") as f:
            f.write(AUTO_HEADER)
            f.write("# fmt: off\n")
            for line in out:
                f.write(line + "\n")

    def write_init_py(self):
        """Write `__init__.py` importing every generated message and enum."""
        filename = os.path.join(self.out_dir, "__init__.py")
        with open(filename, "w") as init_py:
            init_py.write(AUTO_HEADER)
            init_py.write("# fmt: off\n\n")
            for message in sorted(self.messages, key=lambda m: m.name):
                init_py.write(self.create_message_import(message.name) + "\n")
            for enum in sorted(self.enums, key=lambda e: e.name):
                init_py.write("from . import {}\n".format(enum.name))

    def write_classes(self, out_dir, init_py=True):
        """Generate all modules into `out_dir`, optionally with `__init__.py`."""
        self.out_dir = out_dir
        self.process_messages(self.messages)
        self.process_enums(self.enums)
        if init_py:
            self.write_init_py()


def main():
    """Command-line entry point: compile the .proto files and emit modules."""
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("proto", nargs="+", help="Protobuf definition files")
    # The output directory is used unconditionally below, so it must be given.
    parser.add_argument("-o", "--out-dir", required=True, help="Directory for generated source code")
    parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
    parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
    parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
    parser.add_argument("--protobuf-default-include", default="/usr/include", help="Location of protobuf's default .proto files.")
    # fmt: on
    args = parser.parse_args()

    descriptor_proto = protoc(args.proto, (args.protobuf_default_include,))
    descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)

    # Generate into a scratch directory first so that a failure midway does
    # not leave the real output directory half-updated.
    with tempfile.TemporaryDirectory() as tmpdir:
        descriptor.write_classes(tmpdir, not args.no_init_py)

        # Remove previously generated files, recognised by their first line.
        for filename in os.listdir(args.out_dir):
            pathname = os.path.join(args.out_dir, filename)
            try:
                with open(pathname, "r") as f:
                    is_generated = next(f, None) == AUTO_HEADER
                # Unlink only after the file is closed; removing an open file
                # fails on Windows.
                if is_generated:
                    os.unlink(pathname)
            except (OSError, UnicodeError) as e:
                # Best effort: directories, unreadable or non-text files are
                # simply left alone.
                logging.debug("Skipping %s: %s", pathname, e)

        for filename in os.listdir(tmpdir):
            src = os.path.join(tmpdir, filename)
            shutil.copy(src, args.out_dir)


if __name__ == "__main__":
    main()