#!/usr/bin/env python3
# Converts Google's protobuf python definitions of TREZOR wire messages
# to plain-python objects as used in TREZOR Core and python-trezor

import argparse
import importlib
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from collections import namedtuple

from google.protobuf import descriptor_pb2

# One parsed message field: protobuf metadata plus the type names used in the
# generated code. `orig` is the FieldDescriptorProto it was built from.
ProtoField = namedtuple(
    "ProtoField", "name, number, proto_type, py_type, repeated, required, orig"
)

# First line of every generated file. The main script also uses it to
# recognize (and replace) previously generated files in the output directory.
AUTO_HEADER = "# Automatically generated by pb2py\n"

# Maps protobuf scalar field types to a pair of
# (runtime field type used in generated code, python type used in annotations).
# Field types missing here make Descriptor.parse_field raise ValueError.
# fmt: off
FIELD_TYPES = {
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64:  ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32:  ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM:    ('p.UVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32:  ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64:  ('p.SVarintType', 'int'),
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING:  ('p.UnicodeType', 'str'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL:    ('p.BoolType', 'bool'),
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES:   ('p.BytesType', 'bytes'),
}
# fmt: on

PROTOC = shutil.which("protoc")
if not PROTOC:
    # sys.exit() with a string prints it to stderr and exits with status 1,
    # so the diagnostic does not end up mixed into regular stdout output.
    sys.exit("protoc command not found")

# Assumes a conventional install layout (<prefix>/bin/protoc next to
# <prefix>/include holding protoc's bundled .proto files) -- not verified here.
PROTOC_PREFIX = os.path.dirname(os.path.dirname(PROTOC))
PROTOC_INCLUDE = os.path.join(PROTOC_PREFIX, "include")


def protoc(files, additional_includes=()):
    """Compile ``files`` with protoc and return the serialized descriptor set.

    The ``-I`` include path is, in this order: protoc's bundled include
    directory, then ``additional_includes``, then the directory of every input
    file. Falsy entries (e.g. ``None`` from an unset environment variable) are
    skipped and duplicates are dropped, keeping the first occurrence.

    :param files: iterable of .proto file paths
    :param additional_includes: iterable of extra include directories
    :returns: bytes of a serialized ``FileDescriptorSet``
    :raises subprocess.CalledProcessError: if protoc exits with an error
    """
    files = list(files)

    candidates = (
        [PROTOC_INCLUDE]
        + list(additional_includes)
        + [os.path.dirname(file) or "." for file in files]
    )
    # Ordered de-duplication: protoc searches --proto_path entries in order,
    # so a stable order (unlike iterating a set, whose string order varies
    # between runs) keeps the result reproducible.
    include_dirs = []
    for include_dir in candidates:
        if include_dir and include_dir not in include_dirs:
            include_dirs.append(include_dir)
    protoc_includes = ["-I" + include_dir for include_dir in include_dirs]

    # Note that we could avoid creating temp files if protoc let us write to stdout
    # directly. This is currently only possible on Unix, by passing /dev/stdout as
    # the file name. Since there's no direct Windows equivalent, not counting
    # being creative with named pipes, special-casing this is not worth the effort.
    with tempfile.TemporaryDirectory() as tmpdir:
        outfile = os.path.join(tmpdir, "DESCRIPTOR_SET")
        subprocess.check_call(
            [PROTOC, "--descriptor_set_out={}".format(outfile)]
            + protoc_includes
            + files
        )
        with open(outfile, "rb") as f:
            return f.read()


def strip_leader(s, prefix):
    """Return ``s`` without a leading ``"<prefix>_"``; otherwise return ``s`` as is.

    >>> strip_leader("MessageType_Initialize", "MessageType")
    'Initialize'
    >>> strip_leader("Other", "MessageType")
    'Other'
    """
    head = "{}_".format(prefix)
    return s[len(head) :] if s.startswith(head) else s


def import_statement_from_path(path):
    """Turn a (possibly relative) dotted module path into an import statement.

    Leading dots mark a relative import. The last dotted component is the
    imported name, and everything before it forms the ``from`` part.

    >>> import_statement_from_path("protobuf")
    'import protobuf'
    >>> import_statement_from_path("..protobuf")
    'from .. import protobuf'
    >>> import_statement_from_path("trezorlib.protobuf")
    'from trezorlib import protobuf'
    """
    # separate leading dots
    stripped = path.lstrip(".")
    dot_prefix = path[: len(path) - len(stripped)]

    # split the last component off the remaining dotted path; `leader` is a
    # string (possibly empty), so it can be concatenated with the dot prefix
    leader, _, import_name = stripped.rpartition(".")
    from_part = dot_prefix + leader

    if from_part:
        return "from {} import {}".format(from_part, import_name)
    return "import {}".format(import_name)


class Descriptor:
    """Generate plain-python modules from a protoc descriptor set.

    Collects all messages and enums (including nested ones) from the parsed
    ``FileDescriptorSet``. ``write_classes`` then emits one ``<Name>.py`` per
    message and per enum, plus an optional ``__init__.py``.
    """

    def __init__(self, data, message_type="MessageType", import_path="protobuf"):
        """Parse a serialized ``FileDescriptorSet``.

        :param data: bytes as produced by ``protoc --descriptor_set_out``
        :param message_type: name of the enum whose values carry message wire
            IDs (``<message_type>_<MessageName> = <id>``)
        :param import_path: dotted path of the runtime protobuf module that the
            generated code imports as ``p`` (leading dots = relative import)
        :raises RuntimeError: if no messages and no enums are found
        """
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)

        self.files = self.descriptor.file

        logging.debug("found {} files".format(len(self.files)))

        # find messages and enums, including those nested inside messages
        self.messages = []
        self.enums = []
        for file in self.files:
            self.messages += file.message_type
            self.enums += file.enum_type
            for message in file.message_type:
                self._nested_types_from_message(message)

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        self.message_types = self.find_message_types(message_type)
        self.protobuf_import = import_statement_from_path(import_path)

        # output directory; set by write_classes()
        self.out_dir = None

    def _nested_types_from_message(self, message):
        """Recursively add messages/enums nested inside ``message``."""
        self.messages += message.nested_type
        self.enums += message.enum_type
        for nested in message.nested_type:
            self._nested_types_from_message(nested)

    def find_message_types(self, message_type):
        """Map message names to wire IDs using the enum named ``message_type``.

        Value names are stripped of the ``<message_type>_`` prefix, so
        ``MessageType_Initialize = 0`` yields ``{"Initialize": 0}``.
        Returns an empty dict and logs a warning if no such enum exists.
        """
        message_types = {}
        message_type_enum = next(
            (enum for enum in self.enums if enum.name == message_type), None
        )
        if message_type_enum is None:
            # No message type found. Oh well.
            # (Use the parameter, not a script-level global, so this also
            # works when the class is used as a library.)
            logging.warning("Message IDs not found under '{}'".format(message_type))
            return message_types

        for value in message_type_enum.value:
            name = strip_leader(value.name, message_type)
            message_types[name] = value.number

        return message_types

    def parse_field(self, field):
        """Convert a ``FieldDescriptorProto`` into a ``ProtoField``.

        Message-typed fields use the bare message name (package path dropped)
        for both the runtime and the annotation type. Repeated fields get a
        ``List[...]`` annotation.

        :raises ValueError: for field types not present in ``FIELD_TYPES``
        """
        repeated = field.label == field.LABEL_REPEATED
        required = field.label == field.LABEL_REQUIRED
        if field.type == field.TYPE_MESSAGE:
            # ignore package path
            type_name = field.type_name.rsplit(".")[-1]
            proto_type = py_type = type_name
        else:
            try:
                proto_type, py_type = FIELD_TYPES[field.type]
            except KeyError:
                raise ValueError(
                    "Unknown field type {} for field {}".format(field.type, field.name)
                ) from None

        if repeated:
            py_type = "List[{}]".format(py_type)

        return ProtoField(
            name=field.name,
            number=field.number,
            proto_type=proto_type,
            py_type=py_type,
            repeated=repeated,
            required=required,
            orig=field,
        )

    def create_message_import(self, name):
        """Return the relative import line for the generated module ``name``."""
        return "from .{0} import {0}".format(name)

    def process_message_imports(self, fields):
        """Yield import lines for every message type referenced by ``fields``."""
        imports = set(
            field.proto_type
            for field in fields
            if field.orig.type == field.orig.TYPE_MESSAGE
        )

        if imports:
            yield ""  # make isort happy
        for name in sorted(imports):
            yield self.create_message_import(name)

    def create_init_method(self, fields):
        """Yield source lines of the generated ``__init__``.

        Every field becomes a keyword argument defaulting to ``None``; repeated
        fields are initialized to a fresh empty list instead of ``None``.
        """
        # please keep the yields aligned
        # fmt: off
        ... # https://github.com/ambv/black/issues/385
        yield         "    def __init__("
        yield         "        self,"
        for field in fields:
            yield     "        {}: {} = None,".format(field.name, field.py_type)
        yield         "    ) -> None:"

        for field in fields:
            if field.repeated:
                yield "        self.{0} = {0} if {0} is not None else []".format(field.name)
            else:
                yield "        self.{0} = {0}".format(field.name)
        # fmt: on

    def create_fields_method(self, fields):
        """Yield source lines of the generated ``get_fields`` classmethod.

        It returns a dict ``{number: (name, runtime type, flags)}``. Required
        fields and proto defaults are noted in trailing comments.
        """
        # fmt: off
        yield "    @classmethod"
        yield "    def get_fields(cls):"
        yield "        return {"
        for field in fields:
            comments = []
            if field.required:
                comments.append("required")
            if field.orig.HasField("default_value"):
                comments.append("default={}".format(field.orig.default_value))

            if comments:
                comment = "  # " + " ".join(comments)
            else:
                comment = ""

            if field.repeated:
                flags = "p.FLAG_REPEATED"
            else:
                flags = "0"

            yield "            {num}: ('{name}', {type}, {flags}),{comment}".format(
                num=field.number,
                name=field.name,
                type=field.proto_type,
                flags=flags,
                comment=comment,
            )

        yield "        }"
        # fmt: on

    def process_message(self, message):
        """Yield the source lines of the module generated for ``message``."""
        logging.debug("Processing message {}".format(message.name))
        msg_id = self.message_types.get(message.name)

        # "from .. import protobuf as p"
        yield self.protobuf_import + " as p"

        fields = [self.parse_field(field) for field in message.field]

        yield from self.process_message_imports(fields)

        if any(field.repeated for field in fields):
            yield ""
            yield "if __debug__:"
            yield "    try:"
            yield "        from typing import List"
            yield "    except ImportError:"
            yield "        List = None  # type: ignore"

        yield ""
        yield ""
        yield "class {}(p.MessageType):".format(message.name)

        if msg_id is not None:
            yield "    MESSAGE_WIRE_TYPE = {}".format(msg_id)

        if fields:
            yield ""
            yield from self.create_init_method(fields)
            yield ""
            yield from self.create_fields_method(fields)

        # The class body is empty only when there is neither a field nor a wire
        # ID; a wire ID of 0 is a real attribute, so compare against None.
        if not fields and msg_id is None:
            yield "    pass"

    def process_enum(self, enum):
        """Yield ``NAME = number`` lines for each value of ``enum``."""
        logging.debug("Processing enum {}".format(enum.name))

        for value in enum.value:
            # Remove type name from the beginning of the constant
            # For example "PinMatrixRequestType_Current" -> "Current"
            enum_prefix = enum.name
            name = value.name
            name = strip_leader(name, enum_prefix)

            # If type ends with *Type, but constant use type name without *Type, remove it too :)
            # For example "ButtonRequestType & ButtonRequest_Other" => "Other"
            if enum_prefix.endswith("Type"):
                enum_prefix, _ = enum_prefix.rsplit("Type", 1)
                name = strip_leader(name, enum_prefix)

            yield "{} = {}".format(name, value.number)

    def process_messages(self, messages):
        """Write one module per message, in name order."""
        for message in sorted(messages, key=lambda m: m.name):
            self.write_to_file(message.name, self.process_message(message))

    def process_enums(self, enums):
        """Write one module per enum, in name order."""
        for enum in sorted(enums, key=lambda e: e.name):
            self.write_to_file(enum.name, self.process_enum(enum))

    def write_to_file(self, name, out):
        """Write the generated lines ``out`` to ``<out_dir>/<name>.py``."""
        logging.debug("Writing file {}.py".format(name))
        # explicit encoding: names/defaults from .proto files may be non-ASCII
        path = os.path.join(self.out_dir, name + ".py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(AUTO_HEADER)
            f.write("# fmt: off\n")
            for line in out:
                f.write(line + "\n")

    def write_init_py(self):
        """Write ``__init__.py`` importing every generated message and enum."""
        filename = os.path.join(self.out_dir, "__init__.py")
        with open(filename, "w", encoding="utf-8") as init_py:
            init_py.write(AUTO_HEADER)
            init_py.write("# fmt: off\n\n")
            for message in sorted(self.messages, key=lambda m: m.name):
                init_py.write(self.create_message_import(message.name) + "\n")
            for enum in sorted(self.enums, key=lambda m: m.name):
                init_py.write("from . import {}\n".format(enum.name))

    def write_classes(self, out_dir, init_py=True):
        """Generate all message/enum modules (and optionally ``__init__.py``) into ``out_dir``."""
        self.out_dir = out_dir
        self.process_messages(self.messages)
        self.process_enums(self.enums)
        if init_py:
            self.write_init_py()


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("proto", nargs="+", help="Protobuf definition files")
    # required: without it the stale-file cleanup below would run on the CWD
    parser.add_argument("-o", "--out-dir", required=True, help="Directory for generated source code")
    parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module")
    parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules")
    parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs")
    parser.add_argument("-I", "--protoc-include", action="append", help="protoc include path")
    # fmt: on
    args = parser.parse_args()

    protoc_includes = args.protoc_include or (os.environ.get("PROTOC_INCLUDE"),)
    descriptor_proto = protoc(args.proto, protoc_includes)
    descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module)

    os.makedirs(args.out_dir, exist_ok=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Generate everything first, so a generation error leaves out_dir untouched.
        descriptor.write_classes(tmpdir, not args.no_init_py)

        # Remove previously generated files (recognized by their first line),
        # so modules for messages/enums that no longer exist do not linger.
        for filename in os.listdir(args.out_dir):
            pathname = os.path.join(args.out_dir, filename)
            if not os.path.isfile(pathname):
                continue
            try:
                # errors="replace": the header is ASCII, so undecodable bytes
                # elsewhere must not prevent the check
                with open(pathname, "r", encoding="utf-8", errors="replace") as f:
                    is_generated = next(f, None) == AUTO_HEADER
            except OSError:
                # unreadable file: leave it alone
                continue
            # unlink only after the file is closed (required on Windows)
            if is_generated:
                try:
                    os.unlink(pathname)
                except OSError as e:
                    logging.warning("Could not remove stale file {}: {}".format(pathname, e))

        for filename in os.listdir(tmpdir):
            src = os.path.join(tmpdir, filename)
            shutil.copy(src, args.out_dir)