#!/usr/bin/env python3
# Converts Google's protobuf python definitions of TREZOR wire messages
# to plain-python objects as used in TREZOR Core and python-trezor

import argparse
import importlib
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from collections import namedtuple

from google.protobuf import descriptor_pb2

# Parsed view of one protobuf field; `orig` keeps the raw FieldDescriptorProto.
ProtoField = namedtuple(
    "ProtoField", "name, number, proto_type, py_type, repeated, required, orig"
)

# First line of every generated file; also used to recognise stale generated
# files in the output directory before they are replaced.
AUTO_HEADER = "# Automatically generated by pb2py\n"

# Maps protobuf scalar types to (wire type expression, python annotation).
# fmt: off
FIELD_TYPES = {
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM:   ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: ('p.UnicodeType', 'str'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL:   ('p.BoolType', 'bool'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES:  ('p.BytesType', 'bytes'),
}
# fmt: on

PROTOC = shutil.which("protoc")
if not PROTOC:
    print("protoc command not found")
    sys.exit(1)

# The include directory is looked up next to the directory holding protoc,
# i.e. <prefix>/bin/protoc -> <prefix>/include.
PROTOC_PREFIX = os.path.dirname(os.path.dirname(PROTOC))
PROTOC_INCLUDE = os.path.join(PROTOC_PREFIX, "include")


def protoc(files, additional_includes=()):
    """Compile code with protoc and return the data.

    `files` is a sequence of .proto paths. `additional_includes` is an
    iterable of extra include directories; falsy entries (None, "") are
    ignored. Returns the bytes protoc writes via --descriptor_set_out.
    Raises subprocess.CalledProcessError when protoc exits with an error.
    """
    include_dirs = {PROTOC_INCLUDE}
    include_dirs.update(additional_includes)
    # every input file's own directory is searchable as well
    include_dirs.update(os.path.dirname(path) or "." for path in files)
    # sorted so that the protoc command line is reproducible between runs
    protoc_includes = ["-I" + inc for inc in sorted(filter(None, include_dirs))]

    # Note that we could avoid creating temp files if protoc let us write to stdout
    # directly. This is currently only possible on Unix, by passing /dev/stdout as
    # the file name. Since there's no direct Windows equivalent, not counting
    # being creative with named pipes, special-casing this is not worth the effort.
    with tempfile.TemporaryDirectory() as tmpdir:
        outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
        subprocess.check_call(
            [PROTOC, "--descriptor_set_out={}".format(outfile)]
            + protoc_includes
            + list(files)
        )

        with open(outfile, "rb") as f:
            return f.read()


def strip_leader(s, prefix):
    """Remove given prefix from underscored name.

    "Foo_Bar" with prefix "Foo" becomes "Bar"; names that do not start with
    "<prefix>_" are returned unchanged.
    """
    leader = prefix + "_"
    if s.startswith(leader):
        return s[len(leader):]
    return s


def import_statement_from_path(path):
    """Build the import statement that brings the module at `path` into scope.

    Leading dots are kept as a relative-import prefix, the last dotted
    component becomes the imported name:

        "protobuf"           -> "import protobuf"
        "..protobuf"         -> "from .. import protobuf"
        "trezorlib.protobuf" -> "from trezorlib import protobuf"
    """
    # separate leading dots
    stripped = path.lstrip(".")
    dot_prefix = "." * (len(path) - len(stripped))

    # split on the last remaining dot; `leader` is "" when there is none
    leader, _, import_name = stripped.rpartition(".")

    from_part = dot_prefix + leader
    if from_part:
        return "from {} import {}".format(from_part, import_name)
    return "import {}".format(import_name)


class Descriptor:
    """Generator of plain-python modules from a serialized FileDescriptorSet.

    Collects all (also nested) messages and enums from the descriptor set and
    writes one python module per message and per enum into a directory.
    """

    def __init__(self, data, message_type="MessageType", import_path="protobuf"):
        """Parse `data` and index the messages and enums it declares.

        `data` is a serialized FileDescriptorSet (see `protoc()`),
        `message_type` the name of the enum holding message wire IDs and
        `import_path` the dotted path of the protobuf runtime module that
        generated code imports as `p`.

        Raises RuntimeError when the set contains no messages and no enums.
        """
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)

        self.files = self.descriptor.file

        logging.debug("found {} files".format(len(self.files)))

        # find messages and enums
        self.messages = []
        self.enums = []
        for file in self.files:
            self.messages += file.message_type
            self.enums += file.enum_type
            for message in file.message_type:
                self._nested_types_from_message(message)

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        self.message_types = self.find_message_types(message_type)
        self.protobuf_import = import_statement_from_path(import_path)

        # set by write_classes()
        self.out_dir = None

    def _nested_types_from_message(self, message):
        """Recursively collect messages and enums declared inside `message`."""
        self.messages += message.nested_type
        self.enums += message.enum_type
        for nested in message.nested_type:
            self._nested_types_from_message(nested)

    def find_message_types(self, message_type):
        """Return a {message name: wire ID} dict built from enum `message_type`.

        Enum value names have the "<message_type>_" prefix stripped. When no
        enum of that name exists, a warning is logged and an empty dict is
        returned.
        """
        message_type_enum = next(
            (enum for enum in self.enums if enum.name == message_type), None
        )
        if message_type_enum is None:
            # No message type found. Oh well.
            logging.warning("Message IDs not found under '{}'".format(message_type))
            return {}

        return {
            strip_leader(value.name, message_type): value.number
            for value in message_type_enum.value
        }

    def parse_field(self, field):
        """Convert a FieldDescriptorProto into a ProtoField.

        Message-typed fields use the bare message name (package path dropped)
        for both the wire type and the python annotation; scalar types are
        looked up in FIELD_TYPES. Repeated fields are annotated as List[...].

        Raises ValueError for a scalar type missing from FIELD_TYPES.
        """
        repeated = field.label == field.LABEL_REPEATED
        required = field.label == field.LABEL_REQUIRED
        if field.type == field.TYPE_MESSAGE:
            # ignore package path
            type_name = field.type_name.rsplit(".")[-1]
            proto_type = py_type = type_name
        else:
            try:
                proto_type, py_type = FIELD_TYPES[field.type]
            except KeyError:
                raise ValueError(
                    "Unknown field type {} for field {}".format(field.type, field.name)
                ) from None

        if repeated:
            py_type = "List[{}]".format(py_type)

        return ProtoField(
            name=field.name,
            number=field.number,
            proto_type=proto_type,
            py_type=py_type,
            repeated=repeated,
            required=required,
            orig=field,
        )

    def create_message_import(self, name):
        """Return the relative import line for generated class `name`."""
        return "from .{0} import {0}".format(name)

    def process_message_imports(self, fields):
        """Yield one import line per distinct message type used by `fields`.

        Lines are yielded in sorted order of the type name.
        """
        imports = {
            field.proto_type
            for field in fields
            if field.orig.type == field.orig.TYPE_MESSAGE
        }
        for name in sorted(imports):
            yield self.create_message_import(name)

    def create_init_method(self, fields):
        """Yield source lines of the generated class's `__init__` method.

        Every field becomes a keyword argument defaulting to None; repeated
        fields are stored as an empty list when the argument is None.
        Callers only use this for a non-empty `fields`; with no fields the
        emitted method would have no body.
        """
        # please keep the yields aligned
        # fmt: off
        ...  # https://github.com/ambv/black/issues/385
        yield "    def __init__("
        yield "        self,"
        for field in fields:
            yield "        {}: {} = None,".format(field.name, field.py_type)
        yield "    ) -> None:"

        for field in fields:
            if field.repeated:
                yield "        self.{0} = {0} if {0} is not None else []".format(field.name)
            else:
                yield "        self.{0} = {0}".format(field.name)
        # fmt: on

    def process_message(self, message):
        """Yield the source lines of the module generated for `message`."""
        logging.debug("Processing message {}".format(message.name))

        msg_id = self.message_types.get(message.name)

        # "from .. import protobuf as p"
        yield self.protobuf_import + " as p"
        yield ""

        fields = [self.parse_field(field) for field in message.field]

        if any(field.repeated for field in fields):
            # List is only needed for the annotations of repeated fields; the
            # generated import falls back to None where typing is unavailable
            yield "if __debug__:"
            yield "    try:"
            yield "        from typing import List"
            yield "    except ImportError:"
            yield "        List = None  # type: ignore"

        yield from self.process_message_imports(fields)

        yield ""
        yield ""
        yield "class {}(p.MessageType):".format(message.name)

        if msg_id is not None:
            yield "    MESSAGE_WIRE_TYPE = {}".format(msg_id)

        if fields:
            yield "    FIELDS = {"
            for field in fields:
                comments = []
                if field.required:
                    comments.append("required")
                if field.orig.HasField("default_value"):
                    comments.append("default={}".format(field.orig.default_value))

                comment = "  # " + " ".join(comments) if comments else ""
                flags = "p.FLAG_REPEATED" if field.repeated else "0"

                yield "        {num}: ('{name}', {type}, {flags}),{comment}".format(
                    num=field.number,
                    name=field.name,
                    type=field.proto_type,
                    flags=flags,
                    comment=comment,
                )

            yield "    }"
            yield ""
            yield from self.create_init_method(fields)

        # The class body is empty only when neither a wire type nor fields were
        # emitted. Compare against None: wire ID 0 is a valid, emitted ID.
        if not fields and msg_id is None:
            yield "    pass"

    def process_enum(self, enum):
        """Yield "NAME = number" source lines for every value of `enum`."""
        logging.debug("Processing enum {}".format(enum.name))

        for value in enum.value:
            # Remove type name from the beginning of the constant
            # For example "PinMatrixRequestType_Current" -> "Current"
            enum_prefix = enum.name
            name = strip_leader(value.name, enum_prefix)

            # If type ends with *Type, but constant use type name without *Type, remove it too :)
            # For example "ButtonRequestType & ButtonRequest_Other" => "Other"
            if enum_prefix.endswith("Type"):
                enum_prefix, _ = enum_prefix.rsplit("Type", 1)
                name = strip_leader(name, enum_prefix)

            yield "{} = {}".format(name, value.number)

    def process_messages(self, messages):
        """Write one module per message, in order of message name."""
        for message in sorted(messages, key=lambda m: m.name):
            self.write_to_file(message.name, self.process_message(message))

    def process_enums(self, enums):
        """Write one module per enum, in order of enum name."""
        for enum in sorted(enums, key=lambda e: e.name):
            self.write_to_file(enum.name, self.process_enum(enum))

    def write_to_file(self, name, out):
        """Write lines from iterable `out` to `<out_dir>/<name>.py`.

        The file starts with AUTO_HEADER followed by a "# fmt: off" marker.
        """
        # Write generated sourcecode to given file
        logging.debug("Writing file {}.py".format(name))
        with open(os.path.join(self.out_dir, name + ".py"), "w") as f:
            f.write(AUTO_HEADER)
            f.write("# fmt: off\n")
            for line in out:
                f.write(line + "\n")

    def write_init_py(self):
        """Write `__init__.py` importing every generated class and enum module."""
        filename = os.path.join(self.out_dir, "__init__.py")
        with open(filename, "w") as init_py:
            init_py.write(AUTO_HEADER)
            init_py.write("# fmt: off\n\n")
            for message in sorted(self.messages, key=lambda m: m.name):
                init_py.write(self.create_message_import(message.name) + "\n")
            for enum in sorted(self.enums, key=lambda m: m.name):
                init_py.write("from . import {}\n".format(enum.name))

    def write_classes(self, out_dir, init_py=True):
        """Generate all message and enum modules into the directory `out_dir`.

        `__init__.py` is written as well unless `init_py` is false.
        """
        self.out_dir = out_dir
        self.process_messages(self.messages)
        self.process_enums(self.enums)
        if init_py:
            self.write_init_py()


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("proto", nargs="+", help="Protobuf definition files")
    # required: without a target directory there is nowhere to put the result
    parser.add_argument("-o", "--out-dir", required=True, help="Directory for generated source code")
    parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
    parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
    parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
    parser.add_argument("-I", "--protoc-include", action="append", help="protoc include path")
    # fmt: on
    args = parser.parse_args()

    # an unset PROTOC_INCLUDE yields (None,), which protoc() filters out
    protoc_includes = args.protoc_include or (os.environ.get("PROTOC_INCLUDE"),)
    descriptor_proto = protoc(args.proto, protoc_includes)
    descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)

    # Generate into a scratch directory first, so that a failure during
    # generation leaves the contents of the output directory untouched.
    with tempfile.TemporaryDirectory() as tmpdir:
        descriptor.write_classes(tmpdir, not args.no_init_py)

        # remove previously generated files, recognised by their first line
        for filename in os.listdir(args.out_dir):
            pathname = os.path.join(args.out_dir, filename)
            try:
                with open(pathname, "r") as f:
                    is_generated = next(f, None) == AUTO_HEADER
                # unlink only after the file is closed
                if is_generated:
                    os.unlink(pathname)
            except (OSError, UnicodeDecodeError):
                # best effort: directories and unreadable or non-text files
                # are simply left in place
                continue

        for filename in os.listdir(tmpdir):
            src = os.path.join(tmpdir, filename)
            shutil.copy(src, args.out_dir)