#!/usr/bin/env python3
# Converts Google's protobuf python definitions of Trezor wire messages
# to plain-python objects as used in Trezor Core and python-trezor

import argparse
import importlib
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from collections import namedtuple, defaultdict

import attr
from google.protobuf import descriptor_pb2

# First line of every generated file; also used to recognise (and remove)
# stale generated files in the output directory.
AUTO_HEADER = "# Automatically generated by pb2py\n"

# Maps scalar protobuf field types to (protobuf runtime type, Python annotation type).
# fmt: off
FIELD_TYPES = {
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: ('p.UVarintType', 'int'),
    # TYPE_ENUM is handled separately in ProtoField.from_field, because the
    # generated p.EnumType needs the list of allowed values.
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: ('p.UnicodeType', 'str'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL:   ('p.BoolType',    'bool'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES:  ('p.BytesType',   'bytes'),
}
# fmt: on

PROTOC = shutil.which("protoc")
if not PROTOC:
    print("protoc command not found")
    sys.exit(1)

# protoc's bundled .proto files (google/protobuf/*.proto) live in <prefix>/include.
PROTOC_PREFIX = os.path.dirname(os.path.dirname(PROTOC))
PROTOC_INCLUDE = os.path.join(PROTOC_PREFIX, "include")


@attr.s
class ProtoField:
    """One field of a protobuf message plus the code fragments needed to emit it."""

    name = attr.ib()  # field name as declared in the .proto file
    number = attr.ib()  # field number on the wire
    orig = attr.ib()  # the original FieldDescriptorProto
    repeated = attr.ib()
    required = attr.ib()
    type_name = attr.ib()  # referenced type name with the package path stripped
    proto_type = attr.ib()  # source text of the protobuf runtime type, e.g. 'p.UVarintType'
    py_type = attr.ib()  # source text of the Python annotation type, e.g. 'int'
    default_value = attr.ib()  # source text of the default value, or None

    @property
    def optional(self):
        return not self.required and not self.repeated

    @classmethod
    def from_field(cls, descriptor, field):
        """Build a ProtoField from a FieldDescriptorProto.

        :param descriptor: the Descriptor; its ``enum_types`` must already be
            populated (by ``process_enum``) for enum-typed fields
        :param field: the FieldDescriptorProto to convert
        :raises ValueError: for field types not in FIELD_TYPES (e.g. float/double)
        """
        repeated = field.label == field.LABEL_REPEATED
        required = field.label == field.LABEL_REQUIRED

        # ignore package path
        type_name = field.type_name.rsplit(".", 1)[-1]

        if field.type == field.TYPE_MESSAGE:
            proto_type = py_type = type_name
        elif field.type == field.TYPE_ENUM:
            values = [str(v) for v in descriptor.enum_types[type_name].values()]
            # a one-element tuple literal needs a trailing comma: "(5,)" not "(5)"
            valuestr = ", ".join(values) + ("," if len(values) == 1 else "")
            proto_type = 'p.EnumType("{}", ({}))'.format(type_name, valuestr)
            py_type = "EnumType" + type_name
        else:
            try:
                proto_type, py_type = FIELD_TYPES[field.type]
            except KeyError:
                raise ValueError(
                    "Unknown field type {} for field {}".format(field.type, field.name)
                ) from None

        if not field.HasField("default_value"):
            default_value = None
        elif field.type == field.TYPE_ENUM:
            # enum defaults are stored as the value *name*; emit its number
            default_value = str(descriptor.enum_types[type_name][field.default_value])
        elif field.type == field.TYPE_STRING:
            # string defaults are stored unescaped, so escape them for a literal
            escaped = (
                field.default_value.replace("\\", "\\\\")
                .replace('"', '\\"')
                .replace("\n", "\\n")
            )
            default_value = f'"{escaped}"'
        elif field.type == field.TYPE_BYTES:
            # bytes defaults are stored C-escaped, which a b"..." literal accepts
            default_value = f'b"{field.default_value}"'
        elif field.type == field.TYPE_BOOL:
            default_value = "True" if field.default_value == "true" else "False"
        else:
            default_value = field.default_value

        return cls(
            name=field.name,
            number=field.number,
            orig=field,
            repeated=repeated,
            required=required,
            type_name=type_name,
            proto_type=proto_type,
            py_type=py_type,
            default_value=default_value,
        )


def protoc(files, additional_includes=()):
    """Compile .proto files with protoc and return the serialized descriptor set.

    :param files: paths of the .proto files to compile
    :param additional_includes: extra include directories; falsy entries are ignored
    :returns: raw bytes written by ``protoc --descriptor_set_out``
    :raises subprocess.CalledProcessError: if protoc exits with an error
    """
    files = list(files)

    # dict keys drop duplicates but keep insertion order, so the -I order passed
    # to protoc is the same on every run (a set's order depends on str hashing)
    include_dirs = dict.fromkeys([PROTOC_INCLUDE])
    include_dirs.update(dict.fromkeys(additional_includes))
    for file in files:
        include_dirs[os.path.dirname(file) or "."] = None
    protoc_includes = ["-I" + include for include in include_dirs if include]

    # Note that we could avoid creating temp files if protoc let us write to stdout
    # directly. This is currently only possible on Unix, by passing /dev/stdout as
    # the file name. Since there's no direct Windows equivalent, not counting
    # being creative with named pipes, special-casing this is not worth the effort.
    with tempfile.TemporaryDirectory() as tmpdir:
        outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
        subprocess.check_call(
            [PROTOC, "--descriptor_set_out={}".format(outfile)]
            + protoc_includes
            + files
        )
        with open(outfile, "rb") as f:
            return f.read()


def strip_leader(s, prefix):
    """Remove given prefix from underscored name.

    ``strip_leader("Foo_Bar", "Foo") == "Bar"``; other names are returned unchanged.
    """
    leader = prefix + "_"
    if s.startswith(leader):
        return s[len(leader) :]
    return s


def import_statement_from_path(path):
    """Return an import statement that binds the last component of ``path``.

    ``"protobuf"`` -> ``"import protobuf"``,
    ``"..protobuf"`` -> ``"from .. import protobuf"``,
    ``"trezor.protobuf"`` -> ``"from trezor import protobuf"``.
    """
    # separate leading dots (the relative-import level)
    stripped = path.lstrip(".")
    dot_prefix = path[: len(path) - len(stripped)]
    # split the imported name from its (optional) parent module
    leader, _, import_name = stripped.rpartition(".")

    from_part = dot_prefix + leader
    if from_part:
        return "from {} import {}".format(from_part, import_name)
    return "import {}".format(import_name)


class Descriptor:
    """Generates one Python module per message/enum from a FileDescriptorSet."""

    def __init__(self, data, message_type="MessageType", import_path="protobuf"):
        """
        :param data: serialized FileDescriptorSet, e.g. as returned by ``protoc()``
        :param message_type: name of the enum that maps message names to wire IDs
        :param import_path: module path of the protobuf runtime imported as ``p``
            by generated code
        :raises RuntimeError: if the descriptor set has no messages and no enums
        """
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)

        self.files = self.descriptor.file
        logging.debug("found {} files".format(len(self.files)))

        # find messages and enums (including nested ones)
        self.messages = []
        self.enums = []
        # enum name -> {value name: value number}; filled in by process_enum
        self.enum_types = defaultdict(dict)

        for file in self.files:
            self.messages += file.message_type
            self.enums += file.enum_type
            for message in file.message_type:
                self._nested_types_from_message(message)

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        self.message_types = self.find_message_types(message_type)
        self.protobuf_import = import_statement_from_path(import_path)

        self.out_dir = None

    def _nested_types_from_message(self, message):
        """Recursively collect message and enum types nested inside ``message``."""
        self.messages += message.nested_type
        self.enums += message.enum_type
        for nested in message.nested_type:
            self._nested_types_from_message(nested)

    def find_message_types(self, message_type):
        """Map message names to wire IDs using the enum named ``message_type``.

        Enum value names have the ``<message_type>_`` prefix stripped
        (e.g. ``MessageType_Foo`` -> ``Foo``). Returns an empty dict, and logs a
        warning, when no such enum exists.
        """
        message_type_enum = next(
            (e for e in self.enums if e.name == message_type), None
        )
        if message_type_enum is None:
            # No message type found. Oh well.
            logging.warning("Message IDs not found under '{}'".format(message_type))
            return {}

        return {
            strip_leader(value.name, message_type): value.number
            for value in message_type_enum.value
        }

    def create_message_import(self, name):
        return "from .{0} import {0}".format(name)

    def process_subtype_imports(self, fields):
        """Yield import lines for every message type referenced by ``fields``."""
        imports = set(
            field.proto_type
            for field in fields
            if field.orig.type == field.orig.TYPE_MESSAGE
        )

        if imports:
            yield ""  # make isort happy
        for name in sorted(imports):
            yield self.create_message_import(name)

    def create_init_method(self, fields):
        """Yield the generated ``__init__`` with keyword-only arguments."""
        required_fields = [f for f in fields if f.required]
        repeated_fields = [f for f in fields if f.repeated]
        optional_fields = [f for f in fields if f.optional]
        # please keep the yields aligned
        # fmt: off
        yield "    def __init__("
        yield "        self,"
        yield "        *,"
        for field in required_fields:
            yield f"        {field.name}: {field.py_type},"
        for field in repeated_fields:
            yield f"        {field.name}: List[{field.py_type}] = None,"
        for field in optional_fields:
            yield f"        {field.name}: {field.py_type} = {field.default_value},"
        yield "    ) -> None:"

        # repeated fields default to None in the signature and become a fresh
        # list per instance, avoiding a shared mutable default
        for field in repeated_fields:
            yield f"        self.{field.name} = {field.name} if {field.name} is not None else []"
        for field in required_fields + optional_fields:
            yield f"        self.{field.name} = {field.name}"
        # fmt: on

    def create_fields_method(self, fields):
        """Yield ``get_fields`` mapping field number -> (name, type, flags)."""
        # fmt: off
        yield "    @classmethod"
        yield "    def get_fields(cls) -> Dict:"
        yield "        return {"
        for field in fields:
            comments = []
            if field.default_value is not None:
                comments.append(f"default={field.orig.default_value}")

            if comments:
                comment = "  # " + " ".join(comments)
            else:
                comment = ""

            if field.repeated:
                flags = "p.FLAG_REPEATED"
            elif field.required:
                flags = "p.FLAG_REQUIRED"
            else:
                # optional fields carry their default value in the flags slot
                flags = field.default_value

            yield "            {num}: ('{name}', {type}, {flags}),{comment}".format(
                num=field.number,
                name=field.name,
                type=field.proto_type,
                flags=flags,
                comment=comment,
            )

        yield "        }"
        # fmt: on

    def process_message(self, message, include_deprecated=False):
        """Yield the lines of the generated module for a single message."""
        logging.debug("Processing message {}".format(message.name))
        msg_id = self.message_types.get(message.name)

        # "from .. import protobuf as p"
        yield self.protobuf_import + " as p"

        fields = [ProtoField.from_field(self, field) for field in message.field]
        if not include_deprecated:
            fields = [field for field in fields if not field.orig.options.deprecated]

        yield from self.process_subtype_imports(fields)

        yield ""
        yield "if __debug__:"
        yield "    try:"
        yield "        from typing import Dict, List  # noqa: F401"
        yield "        from typing_extensions import Literal  # noqa: F401"

        # one Literal alias per distinct enum type used by this message's fields
        enum_aliases = {}
        for field in fields:
            if field.type_name in self.enum_types:
                enum_aliases.setdefault(field.py_type, field.type_name)
        for py_type, type_name in enum_aliases.items():
            allowed_values = self.enum_types[type_name].values()
            valuestr = ", ".join(str(v) for v in sorted(allowed_values))
            yield "        {} = Literal[{}]".format(py_type, valuestr)

        yield "    except ImportError:"
        yield "        pass"

        yield ""
        yield ""
        yield "class {}(p.MessageType):".format(message.name)

        # message ID 0 is valid, so compare against None rather than truthiness
        if msg_id is not None:
            yield "    MESSAGE_WIRE_TYPE = {}".format(msg_id)

        if fields:
            yield ""
            yield from self.create_init_method(fields)
            yield ""
            yield from self.create_fields_method(fields)

        if not fields and msg_id is None:
            # the class body would otherwise be empty
            yield "    pass"

    def process_enum(self, enum):
        """Yield the lines of the generated module for an enum.

        Side effect: records each value in ``self.enum_types`` as the generator
        is consumed, which ``ProtoField.from_field`` later reads.
        """
        logging.debug("Processing enum {}".format(enum.name))

        # file header
        yield "if False:"
        yield "    from typing_extensions import Literal"
        yield ""

        # Remove type name from the beginning of the constant,
        # e.g. "PinMatrixRequestType_Current" -> "Current".
        # If the type ends with *Type but the constant uses the name without it,
        # remove that too, e.g. "ButtonRequestType" & "ButtonRequest_Other" -> "Other".
        enum_prefix = enum.name
        if enum_prefix.endswith("Type"):
            short_prefix = enum_prefix[: -len("Type")]
        else:
            short_prefix = None

        for value in enum.value:
            name = strip_leader(value.name, enum_prefix)
            if short_prefix is not None:
                name = strip_leader(name, short_prefix)

            self.enum_types[enum.name][value.name] = value.number
            yield f"{name} = {value.number}  # type: Literal[{value.number}]"

    def process_messages(self, messages, include_deprecated=False):
        for message in sorted(messages, key=lambda m: m.name):
            self.write_to_file(
                message.name, self.process_message(message, include_deprecated)
            )

    def process_enums(self, enums):
        for enum in sorted(enums, key=lambda e: e.name):
            self.write_to_file(enum.name, self.process_enum(enum))

    def write_to_file(self, name, out):
        """Write generated source lines ``out`` to ``<out_dir>/<name>.py``."""
        logging.debug("Writing file {}.py".format(name))
        with open(os.path.join(self.out_dir, name + ".py"), "w", encoding="utf-8") as f:
            f.write(AUTO_HEADER)
            f.write("# fmt: off\n")
            f.writelines(line + "\n" for line in out)

    def write_init_py(self):
        """Write ``__init__.py`` re-exporting every message class and enum module."""
        filename = os.path.join(self.out_dir, "__init__.py")

        with open(filename, "w", encoding="utf-8") as init_py:
            init_py.write(AUTO_HEADER)
            init_py.write("# fmt: off\n\n")
            for message in sorted(self.messages, key=lambda m: m.name):
                init_py.write(self.create_message_import(message.name) + "\n")
            for enum in sorted(self.enums, key=lambda m: m.name):
                init_py.write("from . import {}\n".format(enum.name))

    def write_classes(self, out_dir, init_py=True, include_deprecated=False):
        """Generate all modules into ``out_dir``.

        Enums are processed first because message generation reads the enum
        values that ``process_enum`` records.
        """
        self.out_dir = out_dir
        self.process_enums(self.enums)
        self.process_messages(self.messages, include_deprecated)
        if init_py:
            self.write_init_py()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("proto", nargs="+", help="Protobuf definition files")
    parser.add_argument("-o", "--out-dir", required=True, help="Directory for generated source code")
    parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
    parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
    parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
    parser.add_argument("-I", "--protoc-include", action="append", help="protoc include path")
    parser.add_argument("-v", "--verbose", action="store_true", help="Print debug messages")
    parser.add_argument("-d", "--include-deprecated", action="store_true", help="Include deprecated fields")
    # fmt: on
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    protoc_includes = args.protoc_include or (os.environ.get("PROTOC_INCLUDE"),)

    descriptor_proto = protoc(args.proto, protoc_includes)
    descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)

    with tempfile.TemporaryDirectory() as tmpdir:
        descriptor.write_classes(tmpdir, not args.no_init_py, args.include_deprecated)

        # Remove previously generated files (recognised by AUTO_HEADER) so that
        # modules for removed messages/enums do not linger; other files are kept.
        for filename in os.listdir(args.out_dir):
            pathname = os.path.join(args.out_dir, filename)
            try:
                with open(pathname, "r", encoding="utf-8") as f:
                    is_generated = next(f, None) == AUTO_HEADER
                if is_generated:
                    os.unlink(pathname)
            except (OSError, UnicodeDecodeError) as e:
                # directories, unreadable or non-text files: best effort, skip
                logging.debug("Skipping {}: {}".format(pathname, e))

        for filename in os.listdir(tmpdir):
            src = os.path.join(tmpdir, filename)
            shutil.copy(src, args.out_dir)