mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
487 lines
18 KiB
Python
Executable File
487 lines
18 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Converts Trezor wire message definitions (.proto files, compiled with protoc)
# to plain-python objects as used in Trezor Core and python-trezor
|
|
|
|
import argparse
|
|
import importlib
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
from collections import namedtuple, defaultdict
|
|
|
|
import attr
|
|
import construct as c
|
|
|
|
from google.protobuf import descriptor_pb2
|
|
|
|
|
|
AUTO_HEADER = "# Automatically generated by pb2py\n"

# Scalar protobuf field types, each mapped to a pair of
# (protobuf runtime type expression, Python type annotation) that the
# generated classes use. Enum and message fields are handled separately in
# ProtoField.from_field, which is why TYPE_ENUM and TYPE_MESSAGE are absent.
# fmt: off
FIELD_TYPES = {
    proto_type: (runtime_type, annotation)
    for proto_type, runtime_type, annotation in (
        (descriptor_pb2.FieldDescriptorProto.TYPE_UINT64, "p.UVarintType", "int"),
        (descriptor_pb2.FieldDescriptorProto.TYPE_UINT32, "p.UVarintType", "int"),
        (descriptor_pb2.FieldDescriptorProto.TYPE_SINT32, "p.SVarintType", "int"),
        (descriptor_pb2.FieldDescriptorProto.TYPE_SINT64, "p.SVarintType", "int"),
        (descriptor_pb2.FieldDescriptorProto.TYPE_STRING, "p.UnicodeType", "str"),
        (descriptor_pb2.FieldDescriptorProto.TYPE_BOOL, "p.BoolType", "bool"),
        (descriptor_pb2.FieldDescriptorProto.TYPE_BYTES, "p.BytesType", "bytes"),
    )
}
# fmt: on

# construct grammar for a sequence of (key, value) varint pairs, repeated
# greedily until the input runs out or no further pair parses.
ListOfSimpleValues = c.GreedyRange(c.Struct("key" / c.VarInt, "value" / c.VarInt))
|
|
|
|
|
|
def parse_protobuf_simple(data):
    """Micro-parse protobuf-encoded data.

    Return a dict of ``fieldnum: value`` for every varint-encoded (wire
    type 0) field in ``data``; if a field number repeats, the last value
    wins. Fixed-width (wire types 1 and 5) and length-delimited (wire
    type 2) fields are skipped rather than misread as varints, so they do
    not corrupt the parse of the fields that follow them.

    Raises ValueError on truncated input and on wire types that cannot be
    skipped (the deprecated group types 3/4, or unknown types).
    """
    size = len(data)

    def read_varint(pos):
        # Little-endian base-128: 7 payload bits per byte, MSB = "more follows".
        result = 0
        shift = 0
        while True:
            if pos >= size:
                raise ValueError("Truncated varint in protobuf data")
            byte = data[pos]
            pos += 1
            result |= (byte & 0x7F) << shift
            if not byte & 0x80:
                return result, pos
            shift += 7

    # Byte lengths of the fixed-width wire types: 1 = 64-bit, 5 = 32-bit.
    fixed_sizes = {1: 8, 5: 4}

    values = {}
    pos = 0
    while pos < size:
        key, pos = read_varint(pos)
        # a field key is (field_number << 3) | wire_type
        field_num, wire_type = key >> 3, key & 0x07
        if wire_type == 0:
            value, pos = read_varint(pos)
            values[field_num] = value
            continue
        if wire_type == 2:
            # length-delimited: a varint byte count followed by the payload
            skip, pos = read_varint(pos)
        elif wire_type in fixed_sizes:
            skip = fixed_sizes[wire_type]
        else:
            raise ValueError(
                "Unsupported wire type {} for field {}".format(wire_type, field_num)
            )
        pos += skip
        if pos > size:
            raise ValueError("Truncated field {} in protobuf data".format(field_num))
    return values
|
|
|
|
|
|
# Locate the protoc compiler; nothing in this script works without it.
PROTOC = shutil.which("protoc")
if not PROTOC:
    # Report on stderr so the error is not mixed into redirected stdout.
    print("protoc command not found", file=sys.stderr)
    sys.exit(1)

# Include directory derived from protoc's location, assuming a
# <prefix>/bin/protoc layout with bundled .proto files in <prefix>/include
# (NOTE(review): confirm for non-standard installs). It is always passed
# to protoc as an include path.
PROTOC_PREFIX = os.path.dirname(os.path.dirname(PROTOC))
PROTOC_INCLUDE = os.path.join(PROTOC_PREFIX, "include")
|
|
|
|
|
|
@attr.s
class ProtoField:
    """A message field from a FieldDescriptorProto, resolved for code generation.

    ``proto_type``, ``py_type`` and ``default_value`` hold Python source
    fragments that are pasted verbatim into the generated module.
    """

    name = attr.ib()  # field name from the .proto file
    number = attr.ib()  # field number
    orig = attr.ib()  # the FieldDescriptorProto this was built from
    repeated = attr.ib()  # label is LABEL_REPEATED
    required = attr.ib()  # label is LABEL_REQUIRED
    experimental = attr.ib()  # truthiness of the (experimental) option
    type_name = attr.ib()  # message/enum type name without package path
    proto_type = attr.ib()  # runtime type expression, e.g. "p.UVarintType"
    py_type = attr.ib()  # type annotation for the generated __init__
    default_value = attr.ib()  # default as Python source text, or None

    @property
    def optional(self):
        """True if the field is neither required nor repeated."""
        return not self.required and not self.repeated

    @classmethod
    def from_field(cls, descriptor, field):
        """Build a ProtoField from ``field``, a FieldDescriptorProto.

        ``descriptor`` is the owning Descriptor, used to read the custom
        ``experimental`` option and to look up enum values.

        Raises ValueError if a scalar field's type is not in FIELD_TYPES.
        """
        import json

        repeated = field.label == field.LABEL_REPEATED
        required = field.label == field.LABEL_REQUIRED
        experimental = bool(descriptor._get_extension(field, "experimental"))
        # ignore package path: keep only the last dotted component
        type_name = field.type_name.rpartition(".")[2]

        if field.type == field.TYPE_MESSAGE:
            proto_type = py_type = type_name
        elif field.type == field.TYPE_ENUM:
            value_dict = descriptor.enum_types[type_name]
            # repr() of a tuple renders a one-value enum as "(5,)"; joining
            # the values inside "(...)" would give "(5)", which is just an int.
            valuestr = repr(tuple(value_dict.values()))
            proto_type = 'p.EnumType("{}", {})'.format(type_name, valuestr)
            py_type = "EnumType" + type_name
        else:
            try:
                proto_type, py_type = FIELD_TYPES[field.type]
            except KeyError:
                raise ValueError(
                    "Unknown field type {} for field {}".format(field.type, field.name)
                ) from None

        if not field.HasField("default_value"):
            default_value = None
        elif field.type == field.TYPE_ENUM:
            # default_value holds the enum value's name; emit its number
            default_value = str(descriptor.enum_types[type_name][field.default_value])
        elif field.type == field.TYPE_STRING:
            # String defaults are stored unescaped. json.dumps yields a
            # double-quoted literal that is also valid Python and escapes
            # quotes, backslashes and control characters; plain text still
            # comes out exactly as "text" in double quotes.
            default_value = json.dumps(field.default_value, ensure_ascii=False)
        elif field.type == field.TYPE_BYTES:
            # Bytes defaults are stored C-escaped; assumes those escapes are
            # also valid inside a Python bytes literal.
            default_value = f'b"{field.default_value}"'
        elif field.type == field.TYPE_BOOL:
            default_value = "True" if field.default_value == "true" else "False"
        else:
            # numeric defaults are used verbatim
            default_value = field.default_value

        return cls(
            name=field.name,
            number=field.number,
            orig=field,
            repeated=repeated,
            required=required,
            experimental=experimental,
            type_name=type_name,
            proto_type=proto_type,
            py_type=py_type,
            default_value=default_value,
        )
|
|
|
|
|
|
def protoc(files, additional_includes=()):
    """Compile ``files`` with protoc and return the serialized descriptor set.

    Include paths are passed to protoc in a fixed order:
      1. the protoc installation's include directory;
      2. ``additional_includes`` (falsy entries such as ``None`` are skipped);
      3. the directory of each input file.
    Duplicates are dropped, keeping the first occurrence.

    Returns the bytes of a serialized ``FileDescriptorSet``.
    Raises subprocess.CalledProcessError if protoc fails.
    """
    files = list(files)
    # dict.fromkeys de-duplicates while keeping insertion order. A set would
    # shuffle the -I order between runs (str hashing is salted per process),
    # and protoc searches include paths in the order given.
    include_dirs = dict.fromkeys(
        [PROTOC_INCLUDE, *additional_includes]
        + [os.path.dirname(file) or "." for file in files]
    )
    protoc_includes = [
        "-I" + include_dir for include_dir in include_dirs if include_dir
    ]

    # Note that we could avoid creating temp files if protoc let us write to stdout
    # directly. This is currently only possible on Unix, by passing /dev/stdout as
    # the file name. Since there's no direct Windows equivalent, not counting
    # being creative with named pipes, special-casing this is not worth the effort.
    with tempfile.TemporaryDirectory() as tmpdir:
        outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
        subprocess.check_call(
            [PROTOC, "--descriptor_set_out={}".format(outfile)]
            + protoc_includes
            + files
        )
        with open(outfile, "rb") as f:
            return f.read()
|
|
|
|
|
|
def strip_leader(s, prefix):
    """Return ``s`` without a leading ``prefix + "_"``, or ``s`` unchanged.

    >>> strip_leader("MessageType_Initialize", "MessageType")
    'Initialize'
    """
    underscored = prefix + "_"
    return s[len(underscored):] if s.startswith(underscored) else s
|
|
|
|
|
|
def import_statement_from_path(path):
    """Turn a (possibly relative) dotted module path into an import statement.

    Examples:
        "protobuf"            -> "import protobuf"
        "trezorlib.protobuf"  -> "from trezorlib import protobuf"
        "..protobuf"          -> "from .. import protobuf"
        ".pkg.protobuf"       -> "from .pkg import protobuf"

    Raises ValueError if ``path`` has no module name after its leading dots.
    """
    # separate leading dots (the relative-import level)
    bare_path = path.lstrip(".")
    dot_prefix = path[: len(path) - len(bare_path)]

    # split off the last component as the name to import; the rest (a str,
    # possibly empty) is the package it is imported from
    leader, _, import_name = bare_path.rpartition(".")
    if not import_name:
        raise ValueError("Invalid import path: {!r}".format(path))

    from_part = dot_prefix + leader
    if from_part:
        return "from {} import {}".format(from_part, import_name)
    return "import {}".format(import_name)
|
|
|
|
|
|
class Descriptor:
    """Generates Python source modules from a serialized FileDescriptorSet.

    Every top-level and nested message/enum found in the descriptor set gets
    its own ``<Name>.py`` module in the output directory; optionally an
    ``__init__.py`` importing all of them is written too.
    """

    def __init__(self, data, message_type="MessageType", import_path="protobuf"):
        """Parse ``data``, the bytes of a serialized FileDescriptorSet.

        ``message_type`` names the enum whose values map message names to
        wire IDs; ``import_path`` is the dotted path of the protobuf runtime
        module that generated files import as ``p``.

        Raises RuntimeError if the descriptor set has no messages and no enums.
        """
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)

        self.files = self.descriptor.file

        logging.debug("found %d files", len(self.files))

        # find messages and enums
        self.messages = []
        self.enums = []
        # enum name -> {value name: value number}
        self.enum_types = defaultdict(dict)

        # extension name -> field number of that custom option
        self.extensions = {}

        for file in self.files:
            self.messages += file.message_type
            self.enums += file.enum_type
            for message in file.message_type:
                self._nested_types_from_message(message)
            for extension in file.extension:
                self.extensions[extension.name] = extension.number

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        # Fill the enum value tables up front: ProtoField.from_field and
        # process_message look values up here, so they must not depend on
        # process_enum generators having been consumed first.
        for enum in self.enums:
            for value in enum.value:
                self.enum_types[enum.name][value.name] = value.number

        self.message_types = self.find_message_types(message_type)
        self.protobuf_import = import_statement_from_path(import_path)

        self.out_dir = None

    def _get_extension(self, something, extension_name, default=None):
        """Return the value of custom option ``extension_name`` on ``something``.

        ``something`` is a descriptor proto with an ``options`` message (a
        message or field descriptor here). Returns ``default`` if the
        extension is not declared in the parsed files or not set on
        ``something``. Only varint-valued options can be read.
        """
        extension_num = self.extensions.get(extension_name)
        if extension_num is None:
            # The .proto files do not declare this extension at all.
            return default
        # There doesn't seem to be a sane way to access extensions on a descriptor
        # via the google.protobuf API. The "options" message does include the
        # extension data, but the API hides it, so re-encode the options...
        options_bytes = something.options.SerializeToString()
        # ...re-parse them as a dict of field number -> varint value...
        simple_values = parse_protobuf_simple(options_bytes)
        # ...and extract the value corresponding to the extension we care about.
        return simple_values.get(extension_num, default)

    def _nested_types_from_message(self, message):
        """Collect messages and enums nested, at any depth, inside ``message``."""
        self.messages += message.nested_type
        self.enums += message.enum_type
        for nested in message.nested_type:
            self._nested_types_from_message(nested)

    def find_message_types(self, message_type):
        """Map message names to wire IDs using the enum named ``message_type``.

        The ``<message_type>_`` prefix is stripped from value names, e.g.
        ``MessageType_Initialize`` -> ``Initialize``. Returns an empty dict
        (and logs a warning) if no such enum exists.
        """
        message_type_enum = next(
            (e for e in self.enums if e.name == message_type), None
        )
        if message_type_enum is None:
            # No message type found. Oh well.
            logging.warning("Message IDs not found under '%s'", message_type)
            return {}
        return {
            strip_leader(value.name, message_type): value.number
            for value in message_type_enum.value
        }

    def create_message_import(self, name):
        """Return the relative import line for the generated module ``name``."""
        return "from .{0} import {0}".format(name)

    def process_subtype_imports(self, fields):
        """Yield import lines for message types referenced by ``fields``."""
        imports = {
            field.proto_type
            for field in fields
            if field.orig.type == field.orig.TYPE_MESSAGE
        }

        if imports:
            yield ""  # make isort happy
            for name in sorted(imports):
                yield self.create_message_import(name)

    def create_init_method(self, fields):
        """Yield the lines of the generated keyword-only ``__init__``.

        Parameters are ordered required, then repeated (defaulting to a
        fresh empty list), then optional (defaulting to the field default).
        """
        required_fields = [f for f in fields if f.required]
        repeated_fields = [f for f in fields if f.repeated]
        optional_fields = [f for f in fields if f.optional]
        # please keep the yields aligned
        # fmt: off
        yield "    def __init__("
        yield "        self,"
        yield "        *,"
        for field in required_fields:
            yield f"        {field.name}: {field.py_type},"
        for field in repeated_fields:
            yield f"        {field.name}: List[{field.py_type}] = None,"
        for field in optional_fields:
            yield f"        {field.name}: {field.py_type} = {field.default_value},"
        yield "    ) -> None:"

        for field in repeated_fields:
            yield f"        self.{field.name} = {field.name} if {field.name} is not None else []"
        for field in required_fields + optional_fields:
            yield f"        self.{field.name} = {field.name}"
        # fmt: on

    def create_fields_method(self, fields):
        """Yield the lines of the generated ``get_fields`` classmethod.

        Raises ValueError for unsupported combinations: default values on
        repeated, required or experimental fields, and repeated or required
        experimental fields.
        """
        # fmt: off
        yield "    @classmethod"
        yield "    def get_fields(cls) -> Dict:"
        yield "        return {"
        for field in fields:
            comments = []
            if field.default_value is not None:
                comments.append(f"default={field.orig.default_value}")

            if comments:
                comment = "  # " + " ".join(comments)
            else:
                comment = ""

            if field.repeated:
                if field.default_value is not None:
                    raise ValueError("Repeated fields can't have default values.")
                if field.experimental:
                    raise ValueError("Repeated experimental fields are currently not supported.")
                flags = "p.FLAG_REPEATED"
            elif field.required:
                if field.default_value is not None:
                    raise ValueError("Required fields can't have default values.")
                if field.experimental:
                    raise ValueError("Required fields can't be experimental.")
                flags = "p.FLAG_REQUIRED"
            elif field.experimental:
                if field.default_value is not None:
                    raise ValueError("Experimental fields can't have default values.")
                flags = "p.FLAG_EXPERIMENTAL"
            else:
                # plain optional field: the slot carries the default (or None)
                flags = field.default_value

            yield "            {num}: ('{name}', {type}, {flags}),{comment}".format(
                num=field.number,
                name=field.name,
                type=field.proto_type,
                flags=flags,
                comment=comment,
            )

        yield "        }"
        # fmt: on

    def process_message(self, message, include_deprecated=False):
        """Yield the lines of the generated module for ``message``.

        Deprecated fields are left out unless ``include_deprecated`` is set.
        """
        logging.debug("Processing message %s", message.name)

        # The (wire_type) option takes precedence over the message-type enum.
        msg_id = self._get_extension(message, "wire_type")
        if msg_id is None:
            msg_id = self.message_types.get(message.name)

        unstable = self._get_extension(message, "unstable")

        # "from .. import protobuf as p"
        yield self.protobuf_import + " as p"

        fields = [ProtoField.from_field(self, field) for field in message.field]
        if not include_deprecated:
            fields = [field for field in fields if not field.orig.options.deprecated]

        yield from self.process_subtype_imports(fields)

        yield ""
        yield "if __debug__:"
        yield "    try:"
        yield "        from typing import Dict, List  # noqa: F401"
        yield "        from typing_extensions import Literal  # noqa: F401"

        # One Literal alias per distinct enum type, even if several fields use it.
        enum_fields = {}
        for field in fields:
            if field.type_name in self.enum_types:
                enum_fields.setdefault(field.type_name, field)
        for type_name, field in enum_fields.items():
            allowed_values = self.enum_types[type_name].values()
            valuestr = ", ".join(str(v) for v in sorted(allowed_values))
            yield "        {} = Literal[{}]".format(field.py_type, valuestr)

        yield "    except ImportError:"
        yield "        pass"

        yield ""
        yield ""
        yield "class {}(p.MessageType):".format(message.name)

        if msg_id is not None:
            yield "    MESSAGE_WIRE_TYPE = {}".format(msg_id)

        # Test the option's value, not its presence: an explicit
        # `(unstable) = false` must not mark the message unstable.
        if unstable:
            yield "    UNSTABLE = True"

        if fields:
            yield ""
            yield from self.create_init_method(fields)
            yield ""
            yield from self.create_fields_method(fields)

        # Only emit `pass` when nothing else went into the class body.
        if not fields and msg_id is None and not unstable:
            yield "    pass"

    def process_enum(self, enum):
        """Yield the lines of the generated module for ``enum``.

        Each value becomes a module-level int constant with the enum-name
        prefix stripped (e.g. ``PinMatrixRequestType_Current`` -> ``Current``).
        """
        logging.debug("Processing enum %s", enum.name)

        # file header
        yield "if False:"
        yield "    from typing_extensions import Literal"
        yield ""

        # If the type name ends with "Type", value names may use the name
        # without it, e.g. "ButtonRequest_Other" in ButtonRequestType -> "Other".
        short_prefix = None
        if enum.name.endswith("Type"):
            short_prefix = enum.name.rsplit("Type", 1)[0]

        for value in enum.value:
            # Remove the type name from the beginning of the constant.
            name = strip_leader(value.name, enum.name)
            if short_prefix is not None:
                name = strip_leader(name, short_prefix)
            yield f"{name} = {value.number}  # type: Literal[{value.number}]"

    def process_messages(self, messages, include_deprecated=False):
        """Write one module per message, in name order."""
        for message in sorted(messages, key=lambda m: m.name):
            self.write_to_file(
                message.name, self.process_message(message, include_deprecated)
            )

    def process_enums(self, enums):
        """Write one module per enum, in name order."""
        for enum in sorted(enums, key=lambda e: e.name):
            self.write_to_file(enum.name, self.process_enum(enum))

    def write_to_file(self, name, out):
        """Write the generated lines ``out`` to ``<out_dir>/<name>.py``."""
        logging.debug("Writing file %s.py", name)
        # Generate everything first, so an error while generating does not
        # leave a half-written module behind.
        lines = list(out)
        path = os.path.join(self.out_dir, name + ".py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(AUTO_HEADER)
            f.write("# fmt: off\n")
            for line in lines:
                f.write(line + "\n")

    def write_init_py(self):
        """Write ``__init__.py`` importing every generated message and enum module."""
        filename = os.path.join(self.out_dir, "__init__.py")
        with open(filename, "w", encoding="utf-8") as init_py:
            init_py.write(AUTO_HEADER)
            init_py.write("# fmt: off\n\n")
            for message in sorted(self.messages, key=lambda m: m.name):
                init_py.write(self.create_message_import(message.name) + "\n")
            for enum in sorted(self.enums, key=lambda m: m.name):
                init_py.write("from . import {}\n".format(enum.name))

    def write_classes(self, out_dir, init_py=True, include_deprecated=False):
        """Generate all enum and message modules (and ``__init__.py``) in ``out_dir``."""
        self.out_dir = out_dir
        self.process_enums(self.enums)
        self.process_messages(self.messages, include_deprecated)
        if init_py:
            self.write_init_py()
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("proto", nargs="+", help="Protobuf definition files")
    parser.add_argument("-o", "--out-dir", required=True, help="Directory for generated source code")
    parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
    parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
    parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
    parser.add_argument("-I", "--protoc-include", action="append", help="protoc include path")
    parser.add_argument("-v", "--verbose", action="store_true", help="Print debug messages")
    parser.add_argument("-d", "--include-deprecated", action="store_true", help="Include deprecated fields")
    # fmt: on
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    protoc_includes = args.protoc_include or (os.environ.get("PROTOC_INCLUDE"),)
    descriptor_proto = protoc(args.proto, protoc_includes)
    descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)

    # Generate into a scratch directory first: if generation fails, the
    # existing output directory has not been touched yet.
    with tempfile.TemporaryDirectory() as tmpdir:
        descriptor.write_classes(tmpdir, not args.no_init_py, args.include_deprecated)

        os.makedirs(args.out_dir, exist_ok=True)

        # Remove previously generated files (recognized by their first line),
        # so modules for messages/enums that no longer exist do not linger.
        for filename in os.listdir(args.out_dir):
            pathname = os.path.join(args.out_dir, filename)
            try:
                with open(pathname, "r", encoding="utf-8", errors="replace") as f:
                    first_line = next(f, None)
            except OSError:
                # Directories and unreadable entries are not ours; keep them.
                continue
            # Unlink only after the file is closed: Windows cannot delete open files.
            if first_line == AUTO_HEADER:
                try:
                    os.unlink(pathname)
                except OSError as e:
                    logging.warning("Could not remove stale file %s: %s", pathname, e)

        for filename in os.listdir(tmpdir):
            src = os.path.join(tmpdir, filename)
            shutil.copy(src, args.out_dir)
|