#!/usr/bin/env python3
# Converts Google's protobuf python definitions of TREZOR wire messages
# to plain-python objects as used in TREZOR Core and python-trezor

import argparse
import importlib
import logging
import os
import sys
from collections import namedtuple

# Parsed description of one message field, as built by parse_field().
ProtoField = namedtuple('ProtoField', [
    'name',        # field name taken from the descriptor
    'number',      # field number, used as the key in the generated FIELDS dict
    'proto_type',  # type expression written into the generated FIELDS entry
    'py_type',     # annotation used in the generated __init__ signature
    'repeated',    # True when the field label is LABEL_REPEATED
    'required',    # True when the field label is LABEL_REQUIRED
    'orig',        # the original field descriptor
])


def parse_field(number, field):
    """Build a ``ProtoField`` record from a protobuf field descriptor.

    ``number`` is stored as the record's field number. ``field`` provides the
    name, label and type, and is kept in the record as ``orig``.

    A message-typed field uses the referenced message's name as both the
    proto type and the Python type. Any other type is looked up in a fixed
    table of scalar types. Repeated fields get their Python type wrapped
    as ``List[...]``.

    Raises:
        ValueError: if the field's type is neither a message nor one of the
            scalar types in the table.
    """
    unsigned_varint = ('p.UVarintType', 'int')
    signed_varint = ('p.SVarintType', 'int')
    # Maps a descriptor type constant to (proto type expression, Python type).
    scalar_types = {
        field.TYPE_UINT64: unsigned_varint,
        field.TYPE_UINT32: unsigned_varint,
        field.TYPE_ENUM: unsigned_varint,
        field.TYPE_SINT32: signed_varint,
        field.TYPE_SINT64: signed_varint,
        field.TYPE_STRING: ('p.UnicodeType', 'str'),
        field.TYPE_BOOL: ('p.BoolType', 'bool'),
        field.TYPE_BYTES: ('p.BytesType', 'bytes'),
    }

    is_repeated = field.label == field.LABEL_REPEATED
    is_required = field.label == field.LABEL_REQUIRED

    if field.type == field.TYPE_MESSAGE:
        wire_type = annotation = field.message_type.name
    else:
        try:
            wire_type, annotation = scalar_types[field.type]
        except KeyError:
            raise ValueError("Unknown field type %d for field %s" % (field.type, field.name)) from None

    return ProtoField(
        name=field.name,
        number=number,
        proto_type=wire_type,
        py_type="List[%s]" % annotation if is_repeated else annotation,
        repeated=is_repeated,
        required=is_required,
        orig=field,
    )


def import_pb2(name):
    """Import and return the module called ``<name>_pb2``."""
    module_name = "{}_pb2".format(name)
    return importlib.import_module(module_name)


def create_message_import(name):
    """Return the relative import line for class ``name`` from the sibling module ``name``."""
    return "from .{0} import {0}".format(name)


def remove_from_start(s, prefix):
    """Return ``s`` without its leading ``prefix``, or ``s`` unchanged if it does not start with it."""
    return s[len(prefix):] if s.startswith(prefix) else s


def process_message_imports(descriptor):
    """Yield one import line per distinct message type used by the descriptor's fields.

    Only fields whose type is ``TYPE_MESSAGE`` contribute. Each referenced
    message name appears once, and the lines come out in sorted name order.
    """
    referenced = {
        field.message_type.name
        for field in descriptor.fields
        if field.type == field.TYPE_MESSAGE
    }

    for message_name in sorted(referenced):
        yield create_message_import(message_name)


def create_init_method(fields):
    """Yield the source lines of a generated ``__init__`` method.

    Every field becomes a parameter annotated with its ``py_type`` and
    defaulting to ``None``. The body assigns each parameter to the attribute of
    the same name. Repeated fields fall back to a fresh empty list when the
    argument is ``None``.

    ``fields`` is a sequence of objects with ``name``, ``py_type`` and
    ``repeated`` attributes. An empty sequence yields no lines at all
    rather than failing on the last-field lookup.
    """
    if not fields:
        return

    yield "    def __init__("
    yield "        self,"
    for field in fields[:-1]:
        yield "        %s: %s = None," % (field.name, field.py_type)
    # The last parameter must not have a trailing comma.
    last = fields[-1]
    yield "        %s: %s = None" % (last.name, last.py_type)
    yield "    ) -> None:"

    for field in fields:
        if field.repeated:
            yield "        self.{0} = {0} if {0} is not None else []".format(field.name)
        else:
            yield "        self.{0} = {0}".format(field.name)


def process_message(descriptor, protobuf_module, msg_id, is_upy):
    """Yield the source lines of the generated module for one message type.

    The output consists of the protobuf import, an optional guarded
    ``typing.List`` import (only when a field is repeated), imports of the
    message types used by fields, and the class itself: its
    ``MESSAGE_WIRE_TYPE`` (when ``msg_id`` is not None), its ``FIELDS`` table
    and an ``__init__`` method (when the message has fields).

    Args:
        descriptor: message descriptor; ``name`` and ``fields_by_number`` are
            read here, and it is passed on to ``process_message_imports``.
        protobuf_module: accepted for interface compatibility; not used here.
        msg_id: wire type id of the message, or None if it has none.
        is_upy: when true, emit ``import protobuf as p`` instead of the
            package-relative import.
    """
    logging.debug("Processing message %s", descriptor.name)

    if is_upy:
        yield "import protobuf as p"
    else:
        yield "from .. import protobuf as p"

    fields = [parse_field(number, field)
              for number, field in descriptor.fields_by_number.items()]

    if any(field.repeated for field in fields):
        # The generated code tolerates a missing typing module.
        yield "if __debug__:"
        yield "    try:"
        yield "        from typing import List"
        yield "    except ImportError:"
        yield "        List = None"

    yield from process_message_imports(descriptor)

    yield ""
    yield ""
    yield "class %s(p.MessageType):" % descriptor.name

    has_wire_type = msg_id is not None
    if has_wire_type:
        yield "    MESSAGE_WIRE_TYPE = %d" % msg_id

    if fields:
        yield "    FIELDS = {"
        for field in fields:
            # Trailing comment noting 'required' and/or the declared default.
            comments = []
            if field.required:
                comments.append('required')
            if field.orig.has_default_value:
                comments.append("default=%s" % repr(field.orig.default_value))

            if comments:
                comment = "  # %s" % ' '.join(comments)
            else:
                comment = ''

            if field.repeated:
                flags = 'p.FLAG_REPEATED'
            else:
                flags = '0'

            yield "        %d: ('%s', %s, %s),%s" % (field.number, field.name, field.proto_type, flags, comment)

        yield "    }"
        yield ""
        yield from create_init_method(fields)
    elif not has_wire_type:
        # The class body would otherwise be empty. A wire id of 0 is a real
        # id and already gives the class a body, so compare against None
        # rather than relying on truthiness.
        yield "    pass"


def process_enum(descriptor, is_upy):
    """Yield ``NAME = number`` lines for every value of an enum.

    The enum's own name plus an underscore is stripped from the start of each
    constant, e.g. "PinMatrixRequestType_Current" -> "Current". When the enum
    name ends in "Type", the name without that suffix is stripped as well, so
    for "ButtonRequestType" the constant "ButtonRequest_Other" -> "Other".

    ``is_upy`` is accepted but not used here.
    """
    logging.debug("Processing enum %s", descriptor.name)

    type_name = descriptor.name
    # Prefixes are stripped in this order: full type name first, then the
    # type name with its trailing "Type" removed (if it has one).
    prefixes = ["%s_" % type_name]
    if type_name.endswith("Type"):
        prefixes.append("%s_" % type_name.rsplit("Type", 1)[0])

    for constant, value in descriptor.values_by_name.items():
        short_name = constant
        for prefix in prefixes:
            short_name = remove_from_start(short_name, prefix)

        yield "%s = %s" % (short_name, value.number)


def process_file(descriptor, protobuf_module, genpath, modlist, is_upy):
    """Generate one source file per message and per enum of a pb2 module.

    Messages are handled in sorted name order. Each one is given the wire id
    found under ``MessageType_<name>`` in the ``messages`` pb2 module, or None
    when that lookup raises ValueError. Enums follow, in the order the
    descriptor lists them.

    Args:
        descriptor: file descriptor providing ``message_types_by_name`` and
            ``enum_types_by_name``.
        protobuf_module: passed through to ``process_message``.
        genpath: directory the generated files are written to.
        modlist: optional writable file; when given, one import line per
            generated file is written to it.
        is_upy: passed through to the message and enum generators.
    """
    logging.info("Processing module %s", descriptor.name)

    wire_ids = import_pb2('messages').MessageType

    for msg_name, msg_descriptor in sorted(descriptor.message_types_by_name.items()):
        # A ValueError from the lookup is treated as "no wire id".
        try:
            wire_id = wire_ids.Value("MessageType_%s" % msg_name)
        except ValueError:
            wire_id = None

        lines = process_message(msg_descriptor, protobuf_module, wire_id, is_upy)
        write_to_file(genpath, msg_name, lines)
        if modlist:
            modlist.write("%s\n" % create_message_import(msg_name))

    for enum_name, enum_descriptor in descriptor.enum_types_by_name.items():
        lines = process_enum(enum_descriptor, is_upy)
        write_to_file(genpath, enum_name, lines)
        if modlist:
            # Enums are imported as whole modules rather than as classes.
            modlist.write("from . import %s\n" % enum_name)


def write_to_file(genpath, t, out):
    """Write generated source lines to ``<genpath>/<t>.py``.

    The file is created or overwritten. It starts with an
    "Automatically generated" header line, followed by every item of ``out``
    with a newline appended.

    Args:
        genpath: directory to write into.
        t: base name of the file, without the ``.py`` extension.
        out: iterable of source lines without trailing newlines.
    """
    target = os.path.join(genpath, "%s.py" % t)
    # Python reads source files as UTF-8 by default, so write them that way
    # instead of depending on the platform's locale encoding.
    with open(target, 'w', encoding='utf-8') as f:
        f.write("# Automatically generated by pb2py\n")
        for line in out:
            f.write(line + "\n")


if __name__ == '__main__':
    # Command-line entry point: parse the options, import the requested pb2
    # module and generate sources from its file descriptor.
    logging.basicConfig(level=logging.DEBUG)

    cli = argparse.ArgumentParser()
    cli.add_argument(
        'module',
        help="Name of module to generate",
    )
    cli.add_argument(
        'genpath',
        help="Directory for generated source code",
    )
    cli.add_argument(
        '-P', '--protobuf-module',
        default="protobuf",
        help="Name of protobuf module",
    )
    cli.add_argument(
        '-l', '--modlist',
        type=argparse.FileType('a'),
        help="Generate list of modules",
    )
    cli.add_argument(
        '-p', '--protopath',
        type=str,
        help="Path to search for pregenerated Google's python sources",
    )
    cli.add_argument(
        '-m', '--micropython',
        action='store_true',
        help="Use micropython-favoured source code",
    )
    options = cli.parse_args()

    # The search path has to be extended before the pb2 module is imported.
    if options.protopath:
        sys.path.append(options.protopath)
    pb2_module = import_pb2(options.module)

    process_file(
        pb2_module.DESCRIPTOR,
        options.protobuf_module,
        options.genpath,
        options.modlist,
        options.micropython,
    )