1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 19:50:54 +00:00
trezor-firmware/common/protob/pb2py

490 lines
18 KiB
Python
Executable File

#!/usr/bin/env python3
# Converts Google's protobuf python definitions of Trezor wire messages
# to plain-python objects as used in Trezor Core and python-trezor
import argparse
import importlib
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from collections import namedtuple, defaultdict
import attr
import construct as c
from google.protobuf import descriptor_pb2
# First line written to every generated file. The __main__ section compares
# it with the first line of files already in the output directory to decide
# which ones are stale generated files that may be deleted.
AUTO_HEADER = "# Automatically generated by pb2py\n"
# fmt: off
# Maps scalar protobuf field types to a pair of strings:
# (expression emitted into the generated get_fields() table,
#  Python type name used in the generated __init__ annotations).
# Message and enum fields are not listed here; ProtoField.from_field
# resolves those separately.
FIELD_TYPES = {
descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: ('p.UVarintType', 'int'),
descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: ('p.UVarintType', 'int'),
# descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: ('p.UVarintType', 'int'),
descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: ('p.SVarintType', 'int'),
descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: ('p.SVarintType', 'int'),
descriptor_pb2.FieldDescriptorProto.TYPE_STRING: ('p.UnicodeType', 'str'),
descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: ('p.BoolType', 'bool'),
descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: ('p.BytesType', 'bytes'),
}
# fmt: on
# construct parser used by parse_protobuf_simple: a sequence of
# (key, value) pairs, each encoded as a varint, repeated for as long as the
# input keeps parsing.
ListOfSimpleValues = c.GreedyRange(
c.Struct(
"key" / c.VarInt,
"value" / c.VarInt,
)
)
def parse_protobuf_simple(data):
    """Decode a protobuf blob that contains nothing but varint fields.

    Every field is assumed to use wire type 0 (varint). The result maps each
    field number to its decoded integer value; when a field number occurs
    more than once, the last occurrence wins.
    """
    values_by_number = {}
    for entry in ListOfSimpleValues.parse(data):
        # the low three bits of a protobuf key hold the wire type,
        # the remaining bits are the field number
        field_number = entry.key >> 3
        values_by_number[field_number] = entry.value
    return values_by_number
# Locate the protoc executable on PATH when the module is loaded; the rest of
# the script cannot work without it.
PROTOC = shutil.which("protoc")
if not PROTOC:
    # Passing a string to sys.exit prints it to stderr (not stdout, where it
    # could be mistaken for regular output) and exits with status 1.
    sys.exit("protoc command not found")
# Two directory levels above the executable, i.e. PREFIX for PREFIX/bin/protoc.
PROTOC_PREFIX = os.path.dirname(os.path.dirname(PROTOC))
# Always passed to protoc as an include directory; presumably where the
# installation keeps its bundled .proto files -- verify for unusual layouts.
PROTOC_INCLUDE = os.path.join(PROTOC_PREFIX, "include")
@attr.s
class ProtoField:
    """Code-generation view of a single protobuf field.

    Instances are normally created with :meth:`from_field`, which derives all
    attributes from a protobuf field descriptor.
    """

    # field name as declared in the descriptor
    name = attr.ib()
    # protobuf field number
    number = attr.ib()
    # the field descriptor this object was built from
    orig = attr.ib()
    # True when the field label is LABEL_REPEATED
    repeated = attr.ib()
    # True when the field label is LABEL_REQUIRED
    required = attr.ib()
    # True when the "experimental" extension holds a truthy value
    experimental = attr.ib()
    # referenced type name with the package path removed
    type_name = attr.ib()
    # expression emitted into the generated get_fields() table
    proto_type = attr.ib()
    # type name used in the generated __init__ annotations
    py_type = attr.ib()
    # Python source text of the default value, or None if none is declared
    default_value = attr.ib()

    @property
    def optional(self):
        """Return True when the field is neither required nor repeated."""
        return not self.required and not self.repeated

    @classmethod
    def from_field(cls, descriptor, field):
        """Build a ProtoField from a protobuf field descriptor.

        Args:
            descriptor: the owning ``Descriptor``; used to read the
                "experimental" extension and the collected enum values.
            field: the protobuf field descriptor to convert.

        Raises:
            ValueError: if the field has a scalar type that is not listed in
                FIELD_TYPES.
        """
        repeated = field.label == field.LABEL_REPEATED
        required = field.label == field.LABEL_REQUIRED
        experimental = bool(descriptor._get_extension(field, "experimental"))
        # ignore package path
        type_name = field.type_name.rsplit(".")[-1]

        if field.type == field.TYPE_MESSAGE:
            proto_type = py_type = type_name
        elif field.type == field.TYPE_ENUM:
            value_dict = descriptor.enum_types[type_name]
            valuestr = ", ".join(str(v) for v in value_dict.values())
            if len(value_dict) == 1:
                # "(1)" is a parenthesized int, not a tuple; a one-element
                # tuple needs the trailing comma. Multi-valued enums keep
                # their previous output.
                valuestr += ","
            proto_type = 'p.EnumType("{}", ({}))'.format(type_name, valuestr)
            py_type = "EnumType" + type_name
        else:
            try:
                proto_type, py_type = FIELD_TYPES[field.type]
            except KeyError:
                raise ValueError(
                    "Unknown field type {} for field {}".format(field.type, field.name)
                ) from None

        # The descriptor stores defaults as text; convert that text into
        # Python source suitable for the generated code.
        if not field.HasField("default_value"):
            default_value = None
        elif field.type == field.TYPE_ENUM:
            # enum defaults are given by value name; emit the number instead
            default_value = str(descriptor.enum_types[type_name][field.default_value])
        elif field.type == field.TYPE_STRING:
            default_value = f'"{field.default_value}"'
        elif field.type == field.TYPE_BYTES:
            default_value = f'b"{field.default_value}"'
        elif field.type == field.TYPE_BOOL:
            default_value = "True" if field.default_value == "true" else "False"
        else:
            default_value = field.default_value

        return cls(
            name=field.name,
            number=field.number,
            orig=field,
            repeated=repeated,
            required=required,
            experimental=experimental,
            type_name=type_name,
            proto_type=proto_type,
            py_type=py_type,
            default_value=default_value,
        )
def protoc(files, additional_includes=()):
    """Compile code with protoc and return the data.

    Args:
        files: iterable of .proto file paths to compile.
        additional_includes: extra include directories; falsy entries
            (``None``, empty string) are ignored.

    Returns:
        The bytes that protoc wrote via ``--descriptor_set_out``.

    Raises:
        subprocess.CalledProcessError: if protoc exits with a non-zero status.
    """
    files = list(files)

    candidates = [PROTOC_INCLUDE]
    candidates.extend(additional_includes)
    # the directory of every input file is an include directory as well
    candidates.extend(os.path.dirname(file) or "." for file in files)

    # dict.fromkeys drops duplicates while keeping first-seen order, so the
    # -I flags are passed to protoc in the same order on every run (a set
    # would make the order vary between runs).
    include_dirs = dict.fromkeys(path for path in candidates if path)
    protoc_includes = ["-I" + path for path in include_dirs]

    # Note that we could avoid creating temp files if protoc let us write to stdout
    # directly. this is currently only possible on Unix, by passing /dev/stdout as
    # the file name. Since there's no direct Windows equivalent, not counting
    # being creative with named pipes, special-casing this is not worth the effort.
    with tempfile.TemporaryDirectory() as tmpdir:
        outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
        subprocess.check_call(
            [PROTOC, "--descriptor_set_out={}".format(outfile)]
            + protoc_includes
            + files
        )
        with open(outfile, "rb") as f:
            return f.read()
def strip_leader(s, prefix):
    """Return ``s`` without a leading ``prefix + "_"``, if it has one.

    Names that do not start with the prefix followed by an underscore are
    returned unchanged.
    """
    marker = "{}_".format(prefix)
    return s[len(marker):] if s.startswith(marker) else s
def import_statement_from_path(path):
    """Turn a dotted module path into an import statement.

    The last path component becomes the imported name; everything before it
    (including leading dots of a relative path) becomes the ``from`` part:

        "protobuf"        -> "import protobuf"
        "..protobuf"      -> "from .. import protobuf"
        "trezor.protobuf" -> "from trezor import protobuf"
        "..lib.protobuf"  -> "from ..lib import protobuf"
    """
    # separate leading dots
    stripped = path.lstrip(".")
    dot_prefix = path[: len(path) - len(stripped)]

    # Split on the last remaining dot. rpartition yields plain strings (an
    # empty leader when there is no dot), so the concatenation below is always
    # str + str; slicing the result of rsplit produced a list here and raised
    # TypeError for any path with a package part.
    leader, _, import_name = stripped.rpartition(".")

    from_part = dot_prefix + leader
    if from_part:
        return "from {} import {}".format(from_part, import_name)
    return "import {}".format(import_name)
class Descriptor:
    """Parsed protobuf descriptor set and the code generators built on it.

    ``write_classes`` is the entry point for generating files. Enums are
    processed before messages because ``process_enum`` fills
    ``self.enum_types``, which is read when message fields are converted.
    """

    def __init__(self, data, message_type="MessageType", import_path="protobuf"):
        """Parse a serialized FileDescriptorSet and index its contents.

        Args:
            data: serialized FileDescriptorSet bytes.
            message_type: name of the enum that lists message wire IDs.
            import_path: dotted path of the module that generated message
                files import under the alias ``p``.

        Raises:
            RuntimeError: if the descriptor set has neither messages nor enums.
        """
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)

        self.files = self.descriptor.file

        logging.debug("found {} files".format(len(self.files)))

        # find messages and enums
        self.messages = []
        self.enums = []
        # enum name -> {value name -> number}; filled by process_enum
        self.enum_types = defaultdict(dict)

        # extension name -> extension field number
        self.extensions = {}

        for file in self.files:
            self.messages += file.message_type
            self.enums += file.enum_type
            for message in file.message_type:
                self._nested_types_from_message(message)
            for extension in file.extension:
                self.extensions[extension.name] = extension.number

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        self.message_types = self.find_message_types(message_type)
        self.protobuf_import = import_statement_from_path(import_path)

        # set by write_classes
        self.out_dir = None

    def _get_extension(self, something, extension_name, default=None):
        """Return the value of a varint extension set on ``something.options``.

        Args:
            something: any descriptor object that has an ``options`` attribute.
            extension_name: name of an extension found in the parsed files.
            default: returned when the extension is not set.

        Raises:
            KeyError: if no extension of that name was found in the files.
        """
        # There doesn't seem to be a sane way to access extensions on a descriptor
        # via the google.protobuf API.
        # We do have access to descriptors of the extensions...
        extension_num = self.extensions[extension_name]
        # ...and the "options" descriptor _does_ include the extension data. But while
        # the API provides access to unknown fields, it hides the extensions.
        # What we do is re-encode the options descriptor...
        options_bytes = something.options.SerializeToString()
        # ...and re-parse it as a dict of uvarints...
        simple_values = parse_protobuf_simple(options_bytes)
        # ...and extract the value corresponding to the extension we care about.
        return simple_values.get(extension_num, default)

    def _nested_types_from_message(self, message):
        """Recursively collect messages and enums nested inside ``message``."""
        self.messages += message.nested_type
        self.enums += message.enum_type
        for nested in message.nested_type:
            self._nested_types_from_message(nested)

    def find_message_types(self, message_type):
        """Return a mapping of message name -> wire ID.

        The IDs are read from the enum named ``message_type``; the enum name
        plus underscore is stripped from each value name. If no such enum
        exists, a warning is logged and an empty dict is returned.
        """
        message_type_enum = next(
            (e for e in self.enums if e.name == message_type), None
        )
        if message_type_enum is None:
            # No message type found. Oh well.
            # Report the name we were asked for; this must not depend on the
            # command-line ``args`` global, which only exists when the file
            # runs as a script.
            logging.warning(
                "Message IDs not found under '{}'".format(message_type)
            )
            return {}

        message_types = {}
        for value in message_type_enum.value:
            name = strip_leader(value.name, message_type)
            message_types[name] = value.number
        return message_types

    def create_message_import(self, name):
        """Return the relative import line for the generated class ``name``."""
        return "from .{0} import {0}".format(name)

    def process_subtype_imports(self, fields):
        """Yield import lines for all message types referenced by ``fields``."""
        imports = {
            field.proto_type
            for field in fields
            if field.orig.type == field.orig.TYPE_MESSAGE
        }

        if imports:
            yield ""  # make isort happy
        for name in sorted(imports):
            yield self.create_message_import(name)

    def create_init_method(self, fields):
        """Yield the source lines of the generated keyword-only ``__init__``.

        Parameters are emitted in the order required, repeated, optional.
        Repeated fields default to ``None`` and are replaced by a fresh list
        in the generated body.
        """
        required_fields = [f for f in fields if f.required]
        repeated_fields = [f for f in fields if f.repeated]
        optional_fields = [f for f in fields if f.optional]

        # The method lives inside a class body, so the "def" line is indented
        # by four spaces and its parameters and statements by eight. Without
        # that nesting the generated module does not parse.
        # please keep the yields aligned
        # fmt: off
        yield  "    def __init__("
        yield  "        self,"
        yield  "        *,"
        for field in required_fields:
            yield f"        {field.name}: {field.py_type},"
        for field in repeated_fields:
            yield f"        {field.name}: List[{field.py_type}] = None,"
        for field in optional_fields:
            yield f"        {field.name}: {field.py_type} = {field.default_value},"
        yield  "    ) -> None:"

        for field in repeated_fields:
            yield f"        self.{field.name} = {field.name} if {field.name} is not None else []"
        for field in required_fields + optional_fields:
            yield f"        self.{field.name} = {field.name}"
        # fmt: on

    def create_fields_method(self, fields):
        """Yield the source lines of the generated ``get_fields`` classmethod.

        Each field becomes one ``number: (name, type, flags)`` entry. For
        plain optional fields the flags slot holds the default value.

        Raises:
            ValueError: for unsupported combinations of repeated / required /
                experimental with each other or with a default value.
        """
        # fmt: off
        yield "    @classmethod"
        yield "    def get_fields(cls) -> Dict:"
        yield "        return {"
        for field in fields:
            comments = []
            if field.default_value is not None:
                comments.append(f"default={field.orig.default_value}")

            if comments:
                comment = " # " + " ".join(comments)
            else:
                comment = ""

            if field.repeated:
                if field.default_value is not None:
                    raise ValueError("Repeated fields can't have default values.")
                if field.experimental:
                    raise ValueError("Repeated experimental fields are currently not supported.")
                flags = "p.FLAG_REPEATED"
            elif field.required:
                if field.default_value is not None:
                    raise ValueError("Required fields can't have default values.")
                if field.experimental:
                    raise ValueError("Required fields can't be experimental.")
                flags = "p.FLAG_REQUIRED"
            elif field.experimental:
                if field.default_value is not None:
                    raise ValueError("Experimental fields can't have default values.")
                flags = "p.FLAG_EXPERIMENTAL"
            else:
                flags = field.default_value

            yield "            {num}: ('{name}', {type}, {flags}),{comment}".format(
                num=field.number,
                name=field.name,
                type=field.proto_type,
                flags=flags,
                comment=comment,
            )

        yield "        }"
        # fmt: on

    def process_message(self, message, include_deprecated=False):
        """Yield the source lines of the generated module for ``message``.

        Args:
            message: message descriptor to generate a class for.
            include_deprecated: keep fields whose options mark them as
                deprecated; they are dropped by default.
        """
        logging.debug("Processing message {}".format(message.name))

        # an explicit "wire_type" option wins over the message ID enum
        msg_id = self._get_extension(message, "wire_type")
        if msg_id is None:
            msg_id = self.message_types.get(message.name)
        unstable = self._get_extension(message, "unstable")

        # "from .. import protobuf as p"
        yield self.protobuf_import + " as p"

        fields = [ProtoField.from_field(self, field) for field in message.field]
        if not include_deprecated:
            fields = [field for field in fields if not field.orig.options.deprecated]

        yield from self.process_subtype_imports(fields)

        yield ""
        yield "if __debug__:"
        yield "    try:"
        yield "        from typing import Dict, List # noqa: F401"
        yield "        from typing_extensions import Literal # noqa: F401"

        all_enums = [field for field in fields if field.type_name in self.enum_types]
        for field in all_enums:
            allowed_values = self.enum_types[field.type_name].values()
            valuestr = ", ".join(str(v) for v in sorted(allowed_values))
            yield "        {} = Literal[{}]".format(field.py_type, valuestr)

        yield "    except ImportError:"
        yield "        pass"

        yield ""
        yield ""
        yield "class {}(p.MessageType):".format(message.name)

        if msg_id is not None:
            yield "    MESSAGE_WIRE_TYPE = {}".format(msg_id)

        # Test truthiness, like the "experimental" field option: an option
        # that is present but set to false decodes as 0 and must not mark the
        # message as unstable.
        if unstable:
            yield "    UNSTABLE = True"

        if fields:
            yield ""
            yield from self.create_init_method(fields)
            yield ""
            yield from self.create_fields_method(fields)

        if not fields and not msg_id:
            yield "    pass"

    def process_enum(self, enum):
        """Yield the source lines of the generated module for ``enum``.

        As a side effect, records every value in ``self.enum_types`` under the
        enum name, keyed by the original (unstripped) value name.
        """
        logging.debug("Processing enum {}".format(enum.name))

        # file header
        yield "if __debug__:"
        yield "    try:"
        yield "        from typing_extensions import Literal # noqa: F401"
        yield "    except ImportError:"
        yield "        pass"
        yield ""

        for value in enum.value:
            # Remove type name from the beginning of the constant
            # For example "PinMatrixRequestType_Current" -> "Current"
            enum_prefix = enum.name
            name = value.name
            name = strip_leader(name, enum_prefix)

            # If type ends with *Type, but constant use type name without *Type, remove it too :)
            # For example "ButtonRequestType & ButtonRequest_Other" => "Other"
            if enum_prefix.endswith("Type"):
                enum_prefix, _ = enum_prefix.rsplit("Type", 1)
                name = strip_leader(name, enum_prefix)

            self.enum_types[enum.name][value.name] = value.number
            yield f"{name}: Literal[{value.number}] = {value.number}"

    def process_messages(self, messages, include_deprecated=False):
        """Write one generated file per message, in name order."""
        for message in sorted(messages, key=lambda m: m.name):
            self.write_to_file(
                message.name, self.process_message(message, include_deprecated)
            )

    def process_enums(self, enums):
        """Write one generated file per enum, in name order."""
        for enum in sorted(enums, key=lambda e: e.name):
            self.write_to_file(enum.name, self.process_enum(enum))

    def write_to_file(self, name, out):
        """Write the lines from ``out`` to ``<name>.py`` in ``self.out_dir``.

        The file starts with AUTO_HEADER followed by a ``# fmt: off`` line.
        """
        # Write generated sourcecode to given file
        logging.debug("Writing file {}.py".format(name))
        with open(os.path.join(self.out_dir, name + ".py"), "w") as f:
            f.write(AUTO_HEADER)
            f.write("# fmt: off\n")
            for line in out:
                f.write(line + "\n")

    def write_init_py(self):
        """Write ``__init__.py`` importing every generated message and enum."""
        filename = os.path.join(self.out_dir, "__init__.py")
        with open(filename, "w") as init_py:
            init_py.write(AUTO_HEADER)
            init_py.write("# fmt: off\n\n")
            for message in sorted(self.messages, key=lambda m: m.name):
                init_py.write(self.create_message_import(message.name) + "\n")
            for enum in sorted(self.enums, key=lambda m: m.name):
                init_py.write("from . import {}\n".format(enum.name))

    def write_classes(self, out_dir, init_py=True, include_deprecated=False):
        """Generate all enum and message files into ``out_dir``.

        Args:
            out_dir: existing directory to write the files into.
            init_py: also write ``__init__.py``.
            include_deprecated: forwarded to ``process_messages``.
        """
        self.out_dir = out_dir
        # enums first: processing them fills self.enum_types, which the
        # message generator reads
        self.process_enums(self.enums)
        self.process_messages(self.messages, include_deprecated)
        if init_py:
            self.write_init_py()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("proto", nargs="+", help="Protobuf definition files")
    # The output directory is used unconditionally below, so it must be given;
    # argparse now reports a missing -o up front instead of the script failing
    # on a None path after all the generation work is done.
    parser.add_argument("-o", "--out-dir", required=True, help="Directory for generated source code")
    parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
    parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
    parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
    parser.add_argument("-I", "--protoc-include", action="append", help="protoc include path")
    parser.add_argument("-v", "--verbose", action="store_true", help="Print debug messages")
    parser.add_argument("-d", "--include-deprecated", action="store_true", help="Include deprecated fields")
    # fmt: on
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    # -I options take precedence over the PROTOC_INCLUDE environment variable.
    # An unset variable gives (None,); protoc() ignores falsy entries.
    protoc_includes = args.protoc_include or (os.environ.get("PROTOC_INCLUDE"),)
    descriptor_proto = protoc(args.proto, protoc_includes)
    descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)

    # Generate into a scratch directory first so that a failure during
    # generation leaves the output directory untouched.
    with tempfile.TemporaryDirectory() as tmpdir:
        descriptor.write_classes(tmpdir, not args.no_init_py, args.include_deprecated)

        # Remove previously generated files, recognized by their first line.
        for filename in os.listdir(args.out_dir):
            pathname = os.path.join(args.out_dir, filename)
            try:
                with open(pathname, "r") as f:
                    is_generated = next(f, None) == AUTO_HEADER
                # Unlink only after the file is closed; deleting a file that
                # is still open fails on Windows.
                if is_generated:
                    os.unlink(pathname)
            except (OSError, UnicodeError):
                # Best effort: directories, unreadable files and files that
                # are not valid text are simply left in place.
                continue

        for filename in os.listdir(tmpdir):
            src = os.path.join(tmpdir, filename)
            shutil.copy(src, args.out_dir)