#!/usr/bin/env python3
# Converts Google's protobuf python definitions of Trezor wire messages
# to plain-python objects as used in Trezor Core and python-trezor

import argparse
import importlib
import json
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from collections import namedtuple, defaultdict

import attr
import construct as c

from google.protobuf import descriptor_pb2


# First line written to every generated file. The __main__ block also uses it
# to recognize previously generated files in the output directory.
AUTO_HEADER = "# Automatically generated by pb2py\n"

# Scalar protobuf field type -> (protobuf-runtime type expression, Python type
# annotation). Enum and message fields are not listed here on purpose:
# ProtoField.from_field builds their types itself.
_FieldType = descriptor_pb2.FieldDescriptorProto
# fmt: off
FIELD_TYPES = {
    _FieldType.TYPE_UINT64: ("p.UVarintType", "int"),
    _FieldType.TYPE_UINT32: ("p.UVarintType", "int"),
    _FieldType.TYPE_SINT32: ("p.SVarintType", "int"),
    _FieldType.TYPE_SINT64: ("p.SVarintType", "int"),
    _FieldType.TYPE_STRING: ("p.UnicodeType", "str"),
    _FieldType.TYPE_BOOL:   ("p.BoolType", "bool"),
    _FieldType.TYPE_BYTES:  ("p.BytesType", "bytes"),
}
# fmt: on
del _FieldType

# A greedy sequence of (varint key, varint value) pairs; used by
# parse_protobuf_simple to decode option messages.
ListOfSimpleValues = c.GreedyRange(c.Struct("key" / c.VarInt, "value" / c.VarInt))


def parse_protobuf_simple(data):
    """Decode protobuf bytes whose fields are all varints (wire type 0).

    Each protobuf field key encodes ``(field_number << 3) | wire_type``; the
    wire type bits are dropped without being checked. Returns a dict mapping
    field number to its varint value. If a field number repeats, the last
    occurrence wins.
    """
    fields = {}
    for pair in ListOfSimpleValues.parse(data):
        fields[pair.key >> 3] = pair.value
    return fields


# Locate the protoc compiler; abort early if it is not on PATH.
PROTOC = shutil.which("protoc")
if not PROTOC:
    # sys.exit() with a string prints it to stderr and exits with status 1,
    # keeping the diagnostic off stdout.
    sys.exit("protoc command not found")

# The protoc installation prefix, e.g. /usr/bin/protoc -> /usr. Its "include"
# subdirectory is assumed to hold protoc's bundled .proto files.
PROTOC_PREFIX = os.path.dirname(os.path.dirname(PROTOC))
PROTOC_INCLUDE = os.path.join(PROTOC_PREFIX, "include")


@attr.s
class ProtoField:
    """One message field, with the type information the generator emits."""

    name = attr.ib()  # field name as declared in the .proto file
    number = attr.ib()  # field number (tag)
    orig = attr.ib()  # the original FieldDescriptorProto
    repeated = attr.ib()
    required = attr.ib()
    experimental = attr.ib()  # value of the custom "experimental" option, as bool
    type_name = attr.ib()  # message/enum type name without package path
    proto_type = attr.ib()  # protobuf-runtime type expression used in get_fields()
    py_type = attr.ib()  # Python type annotation used in __init__()
    default_value = attr.ib()  # Python source text of the default, or None

    @property
    def optional(self):
        """True for fields that are neither required nor repeated."""
        return not self.required and not self.repeated

    @classmethod
    def from_field(cls, descriptor, field):
        """Build a ProtoField from a ``FieldDescriptorProto``.

        :param descriptor: the Descriptor being processed; used to read custom
            options and the values of already-processed enums
            (``descriptor.enum_types``).
        :param field: the ``FieldDescriptorProto`` to convert.
        :raises ValueError: for scalar field types not listed in FIELD_TYPES.
        """
        repeated = field.label == field.LABEL_REPEATED
        required = field.label == field.LABEL_REQUIRED
        experimental = bool(descriptor._get_extension(field, "experimental"))
        # ignore package path: ".pkg.Foo" -> "Foo"
        type_name = field.type_name.rpartition(".")[2]

        if field.type == field.TYPE_MESSAGE:
            proto_type = py_type = type_name
        elif field.type == field.TYPE_ENUM:
            value_dict = descriptor.enum_types[type_name]
            valuestr = ", ".join(str(v) for v in value_dict.values())
            proto_type = 'p.EnumType("{}", ({},))'.format(type_name, valuestr)
            py_type = "EnumType" + type_name
        else:
            try:
                proto_type, py_type = FIELD_TYPES[field.type]
            except KeyError:
                raise ValueError(
                    "Unknown field type {} for field {}".format(field.type, field.name)
                ) from None

        if not field.HasField("default_value"):
            default_value = None
        elif field.type == field.TYPE_ENUM:
            # default_value holds the enum value *name*; emit its number
            default_value = str(descriptor.enum_types[type_name][field.default_value])
        elif field.type == field.TYPE_STRING:
            # default_value holds the raw text. Emit an escaped, double-quoted
            # literal so quotes, backslashes or newlines in it cannot break the
            # generated source (identical to '"<text>"' for plain text).
            default_value = json.dumps(field.default_value, ensure_ascii=False)
        elif field.type == field.TYPE_BYTES:
            # descriptor.proto documents bytes defaults as already C-escaped;
            # assumed valid inside a Python bytes literal as-is.
            default_value = f'b"{field.default_value}"'
        elif field.type == field.TYPE_BOOL:
            default_value = "True" if field.default_value == "true" else "False"
        else:
            # numeric defaults are already valid Python source text
            default_value = field.default_value

        return cls(
            name=field.name,
            number=field.number,
            orig=field,
            repeated=repeated,
            required=required,
            experimental=experimental,
            type_name=type_name,
            proto_type=proto_type,
            py_type=py_type,
            default_value=default_value,
        )


def protoc(files, additional_includes=()):
    """Compile .proto files with protoc and return the serialized descriptor set.

    :param files: paths of the .proto files to compile (any iterable of str).
    :param additional_includes: extra include directories. Falsy entries (for
        example ``None`` from an unset environment variable) are ignored.
    :returns: bytes of the ``FileDescriptorSet`` written by
        ``protoc --descriptor_set_out``.
    :raises subprocess.CalledProcessError: if protoc exits with an error.
    """
    files = list(files)

    # protoc searches -I directories in the order given, so de-duplicate while
    # keeping a stable order (a set's order can change from run to run).
    # Order: protoc's own include dir, user-supplied dirs, then each input
    # file's directory.
    candidates = [PROTOC_INCLUDE, *additional_includes]
    candidates += [os.path.dirname(file) or "." for file in files]
    include_dirs = dict.fromkeys(d for d in candidates if d)
    protoc_includes = ["-I" + include_dir for include_dir in include_dirs]

    # Note that we could avoid creating temp files if protoc let us write to stdout
    # directly. this is currently only possible on Unix, by passing /dev/stdout as
    # the file name. Since there's no direct Windows equivalent, not counting
    # being creative with named pipes, special-casing this is not worth the effort.
    with tempfile.TemporaryDirectory() as tmpdir:
        outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
        subprocess.check_call(
            [PROTOC, "--descriptor_set_out={}".format(outfile)]
            + protoc_includes
            + files
        )
        with open(outfile, "rb") as f:
            return f.read()


def strip_leader(s, prefix):
    """Drop a leading ``prefix + "_"`` from *s*; return *s* unchanged otherwise.

    Example: ``strip_leader("MessageType_Ping", "MessageType") == "Ping"``.
    """
    leader = prefix + "_"
    return s[len(leader):] if s.startswith(leader) else s


def import_statement_from_path(path):
    """Build the Python import statement for a (possibly relative) module path.

    Leading dots mark a relative import::

        "protobuf"           -> "import protobuf"
        "..protobuf"         -> "from .. import protobuf"
        "trezorlib.protobuf" -> "from trezorlib import protobuf"
        ".a.b.c"             -> "from .a.b import c"

    :raises ValueError: if *path* does not end with a module name
        (e.g. ``""``, ``"."`` or ``"pkg."``).
    """
    # separate leading dots (relative import level)
    stripped = path.lstrip(".")
    dot_prefix = path[: len(path) - len(stripped)]

    # Split off the last component. rpartition returns strings, so the leader
    # can be joined to the dot prefix directly (a list slice could not).
    leader, _, import_name = stripped.rpartition(".")
    if not import_name:
        raise ValueError("Invalid import path: {!r}".format(path))

    from_part = dot_prefix + leader
    if from_part:
        return "from {} import {}".format(from_part, import_name)
    return "import {}".format(import_name)


class Descriptor:
    """Generate Python message/enum modules from a compiled FileDescriptorSet.

    Typical use: ``Descriptor(data, ...).write_classes(out_dir)``. Enums must be
    processed before messages (``write_classes`` does this). Message fields look
    up enum values in ``self.enum_types``, which ``process_enum`` fills in.
    """

    def __init__(self, data, message_type="MessageType", import_path="protobuf"):
        """Parse a serialized ``FileDescriptorSet`` and collect its definitions.

        :param data: bytes produced by ``protoc --descriptor_set_out``.
        :param message_type: name of the enum that lists message wire IDs.
        :param import_path: module path of the protobuf runtime used in the
            generated import statement (leading dots = relative import).
        :raises RuntimeError: if the set contains no messages and no enums.
        """
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)

        self.files = self.descriptor.file

        logging.debug("found {} files".format(len(self.files)))

        # find messages and enums (including nested ones)
        self.messages = []
        self.enums = []
        # enum name -> {value name: value number}; filled in by process_enum()
        self.enum_types = defaultdict(dict)

        # top-level extension name -> extension field number
        self.extensions = {}

        for file in self.files:
            self.messages += file.message_type
            self.enums += file.enum_type
            for message in file.message_type:
                self._nested_types_from_message(message)
            for extension in file.extension:
                self.extensions[extension.name] = extension.number

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        self.message_types = self.find_message_types(message_type)
        self.protobuf_import = import_statement_from_path(import_path)

        self.out_dir = None

    def _get_extension(self, something, extension_name, default=None):
        """Read a custom (extension) option of a descriptor as an integer.

        :param something: a descriptor proto with an ``options`` field
            (a message or field descriptor).
        :param extension_name: name of a top-level extension declared in the
            compiled .proto files.
        :param default: returned when the extension is not declared in the
            compiled files, or is not set on *something*.

        Only varint-encoded options are decoded; see parse_protobuf_simple.
        """
        extension_num = self.extensions.get(extension_name)
        if extension_num is None:
            # Not declared in any compiled file, so it cannot be set.
            return default
        # There doesn't seem to be a sane way to access extensions on a descriptor
        # via the google.protobuf API.
        # We do have access to descriptors of the extensions...
        # ...and the "options" descriptor _does_ include the extension data. But while
        # the API provides access to unknown fields, it hides the extensions.
        # What we do is re-encode the options descriptor...
        options_bytes = something.options.SerializeToString()
        # ...and re-parse it as a dict of uvarints...
        simple_values = parse_protobuf_simple(options_bytes)
        # ...and extract the value corresponding to the extension we care about.
        return simple_values.get(extension_num, default)

    def _nested_types_from_message(self, message):
        """Recursively add messages and enums nested in *message*."""
        self.messages += message.nested_type
        self.enums += message.enum_type
        for nested in message.nested_type:
            self._nested_types_from_message(nested)

    def find_message_types(self, message_type):
        """Map message names to wire IDs using the enum named *message_type*.

        The ``<message_type>_`` prefix is stripped from value names, e.g.
        ``MessageType_Ping`` -> ``Ping``. If no such enum exists, a warning is
        logged and an empty dict returned.
        """
        message_types = {}
        try:
            message_type_enum = next(e for e in self.enums if e.name == message_type)
        except StopIteration:
            # No message type found. Oh well.
            logging.warning("Message IDs not found under '{}'".format(message_type))
            return message_types

        for value in message_type_enum.value:
            name = strip_leader(value.name, message_type)
            message_types[name] = value.number
        return message_types

    def create_message_import(self, name):
        """Return the relative import line for the generated module *name*."""
        return "from .{0} import {0}".format(name)

    def process_subtype_imports(self, fields):
        """Yield sorted import lines for message types referenced by *fields*."""
        imports = set(
            field.proto_type
            for field in fields
            if field.orig.type == field.orig.TYPE_MESSAGE
        )

        if imports:
            yield ""  # make isort happy
        for name in sorted(imports):
            yield self.create_message_import(name)

    def create_init_method(self, fields):
        """Yield the source lines of the generated keyword-only ``__init__``.

        Repeated fields default to a fresh empty list. Optional fields default
        to their declared default or ``None``.
        """
        required_fields = [f for f in fields if f.required]
        repeated_fields = [f for f in fields if f.repeated]
        optional_fields = [f for f in fields if f.optional]
        # please keep the yields aligned
        # fmt: off
        yield      "    def __init__("
        yield      "        self,"
        yield      "        *,"
        for field in required_fields:
            yield f"        {field.name}: {field.py_type},"
        for field in repeated_fields:
            yield f"        {field.name}: Optional[List[{field.py_type}]] = None,"
        for field in optional_fields:
            if field.default_value is None:
                yield f"        {field.name}: Optional[{field.py_type}] = None,"
            else:
                yield f"        {field.name}: {field.py_type} = {field.default_value},"
        yield      "    ) -> None:"

        for field in repeated_fields:
            yield f"        self.{field.name} = {field.name} if {field.name} is not None else []"
        for field in required_fields + optional_fields:
            yield f"        self.{field.name} = {field.name}"
        # fmt: on

    def create_fields_method(self, fields):
        """Yield the source lines of the generated ``get_fields`` classmethod.

        :raises ValueError: for unsupported combinations of repeated, required
            and experimental fields and default values.
        """
        # fmt: off
        yield "    @classmethod"
        yield "    def get_fields(cls) -> Dict:"
        yield "        return {"
        for field in fields:
            comments = []
            if field.default_value is not None:
                comments.append(f"default={field.orig.default_value}")

            if comments:
                comment = "  # " + " ".join(comments)
            else:
                comment = ""

            if field.repeated:
                if field.default_value is not None:
                    raise ValueError("Repeated fields can't have default values.")
                if field.experimental:
                    raise ValueError("Repeated experimental fields are currently not supported.")
                flags = "p.FLAG_REPEATED"
            elif field.required:
                if field.default_value is not None:
                    raise ValueError("Required fields can't have default values.")
                if field.experimental:
                    raise ValueError("Required fields can't be experimental.")
                flags = "p.FLAG_REQUIRED"
            elif field.experimental:
                if field.default_value is not None:
                    raise ValueError("Experimental fields can't have default values.")
                flags = "p.FLAG_EXPERIMENTAL"
            else:
                # optional field: the third tuple slot carries the default value
                flags = field.default_value

            yield "            {num}: ('{name}', {type}, {flags}),{comment}".format(
                num=field.number,
                name=field.name,
                type=field.proto_type,
                flags=flags,
                comment=comment,
            )

        yield "        }"
        # fmt: on

    def process_message(self, message, include_deprecated=False):
        """Yield the lines of the generated module for one message type.

        :param include_deprecated: also emit fields marked ``deprecated``.
        """
        logging.debug("Processing message {}".format(message.name))

        # prefer the explicit wire_type option, fall back to the MessageType enum
        msg_id = self._get_extension(message, "wire_type")
        if msg_id is None:
            msg_id = self.message_types.get(message.name)

        unstable = self._get_extension(message, "unstable")

        # "from .. import protobuf as p"
        yield self.protobuf_import + " as p"

        fields = [ProtoField.from_field(self, field) for field in message.field]
        if not include_deprecated:
            fields = [field for field in fields if not field.orig.options.deprecated]

        yield from self.process_subtype_imports(fields)

        yield ""
        yield "if __debug__:"
        yield "    try:"
        yield "        from typing import Dict, List, Optional  # noqa: F401"
        yield "        from typing_extensions import Literal  # noqa: F401"

        # Literal aliases for enum-typed fields, e.g. EnumTypeFoo = Literal[0, 1]
        all_enums = [field for field in fields if field.type_name in self.enum_types]
        for field in all_enums:
            allowed_values = self.enum_types[field.type_name].values()
            valuestr = ", ".join(str(v) for v in sorted(allowed_values))
            yield "        {} = Literal[{}]".format(field.py_type, valuestr)

        yield "    except ImportError:"
        yield "        pass"

        yield ""
        yield ""
        yield "class {}(p.MessageType):".format(message.name)

        if msg_id is not None:
            yield "    MESSAGE_WIRE_TYPE = {}".format(msg_id)

        if unstable is not None:
            yield "    UNSTABLE = True"

        if fields:
            yield ""
            yield from self.create_init_method(fields)
            yield ""
            yield from self.create_fields_method(fields)

        # Only add "pass" when nothing above produced a class-body statement.
        # Compare with None like the checks above: a wire ID of 0 is valid.
        if not fields and msg_id is None and unstable is None:
            yield "    pass"

    def process_enum(self, enum):
        """Yield the lines of the generated module for one enum.

        Side effect (when the generator is consumed): records the enum's
        values in ``self.enum_types``, keyed by the full value names.
        """
        logging.debug("Processing enum {}".format(enum.name))

        # file header
        yield "if __debug__:"
        yield "    try:"
        yield "        from typing_extensions import Literal  # noqa: F401"
        yield "    except ImportError:"
        yield "        pass"
        yield ""

        for value in enum.value:
            # Remove type name from the beginning of the constant
            # For example "PinMatrixRequestType_Current" -> "Current"
            enum_prefix = enum.name
            name = value.name
            name = strip_leader(name, enum_prefix)

            # If type ends with *Type, but constant use type name without *Type, remove it too :)
            # For example "ButtonRequestType & ButtonRequest_Other" => "Other"
            if enum_prefix.endswith("Type"):
                enum_prefix, _ = enum_prefix.rsplit("Type", 1)
                name = strip_leader(name, enum_prefix)

            self.enum_types[enum.name][value.name] = value.number
            yield f"{name}: Literal[{value.number}] = {value.number}"

    def process_messages(self, messages, include_deprecated=False):
        """Write one module per message, in name order."""
        for message in sorted(messages, key=lambda m: m.name):
            self.write_to_file(
                message.name, self.process_message(message, include_deprecated)
            )

    def process_enums(self, enums):
        """Write one module per enum, in name order."""
        for enum in sorted(enums, key=lambda e: e.name):
            self.write_to_file(enum.name, self.process_enum(enum))

    def write_to_file(self, name, out):
        """Write the header plus the lines from *out* to ``<out_dir>/<name>.py``."""
        logging.debug("Writing file {}.py".format(name))
        # explicit UTF-8 so the output doesn't depend on the platform locale
        with open(os.path.join(self.out_dir, name + ".py"), "w", encoding="utf-8") as f:
            f.write(AUTO_HEADER)
            f.write("# fmt: off\n")
            f.write("# isort:skip_file\n")
            for line in out:
                f.write(line + "\n")

    def write_init_py(self):
        """Write ``__init__.py`` importing every generated message and enum."""
        filename = os.path.join(self.out_dir, "__init__.py")
        with open(filename, "w", encoding="utf-8") as init_py:
            init_py.write(AUTO_HEADER)
            init_py.write("# fmt: off\n")
            init_py.write("# isort:skip_file\n\n")
            for message in sorted(self.messages, key=lambda m: m.name):
                init_py.write(self.create_message_import(message.name) + "\n")
            for enum in sorted(self.enums, key=lambda m: m.name):
                init_py.write("from . import {}\n".format(enum.name))

    def write_classes(self, out_dir, init_py=True, include_deprecated=False):
        """Generate all modules into *out_dir*.

        Enums are written first because message generation reads
        ``self.enum_types``, which enum processing fills in.
        """
        self.out_dir = out_dir
        self.process_enums(self.enums)
        self.process_messages(self.messages, include_deprecated)
        if init_py:
            self.write_init_py()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("proto", nargs="+", help="Protobuf definition files")
    parser.add_argument("-o", "--out-dir", required=True, help="Directory for generated source code")
    parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
    parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
    parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
    parser.add_argument("-I", "--protoc-include", action="append", help="protoc include path")
    parser.add_argument("-v", "--verbose", action="store_true", help="Print debug messages")
    parser.add_argument("-d", "--include-deprecated", action="store_true", help="Include deprecated fields")
    # fmt: on
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    # If PROTOC_INCLUDE is unset this is (None,); protoc() skips falsy dirs.
    protoc_includes = args.protoc_include or (os.environ.get("PROTOC_INCLUDE"),)
    descriptor_proto = protoc(args.proto, protoc_includes)
    descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)

    os.makedirs(args.out_dir, exist_ok=True)

    # Generate into a scratch directory first, so out_dir is only touched
    # after generation has succeeded.
    with tempfile.TemporaryDirectory() as tmpdir:
        descriptor.write_classes(tmpdir, not args.no_init_py, args.include_deprecated)

        # Remove files from earlier runs (recognized by AUTO_HEADER), so modules
        # for definitions that no longer exist are not left behind.
        for filename in os.listdir(args.out_dir):
            pathname = os.path.join(args.out_dir, filename)
            try:
                with open(pathname, "r", encoding="utf-8") as f:
                    is_generated = next(f, None) == AUTO_HEADER
            except (OSError, UnicodeDecodeError):
                # directories, unreadable or non-text files: not ours
                continue
            # unlink only after the file is closed (required on Windows)
            if is_generated:
                try:
                    os.unlink(pathname)
                except OSError as e:
                    logging.warning("Could not remove stale file %s: %s", pathname, e)

        for filename in os.listdir(tmpdir):
            src = os.path.join(tmpdir, filename)
            shutil.copy(src, args.out_dir)