"""
Helper for handling `local_cache_attribute` size optimizations in `core/src` code.

Can translate `# local_cache_attribute: ...` comments into code changes and vice versa.

Possible improvements (TODO):
- do not stop the renaming process when an inner function is encountered
    def abc(self):
        x = self.x  # local_cache_attribute
        ghi(x)
        y = x

        def ggg():
            ...

        ghi(self.x)
        y = self.x
- do not rename when the new name is already a global symbol
    import multisig
    def abc(msg):
        multisig = msg.multisig  # local_cache_attribute
        multisig.ask(multisig)  # ERROR
- do not rename two caches with the same name
    slice_view = aprime.slice_view  # local_cache_attribute
    slice_view = bprime.slice_view  # local_cache_attribute
"""

from __future__ import annotations

import sys
from pathlib import Path

import click

# libcst is needed by everything below; if it cannot be imported, print an
# installation hint and exit with status 1 instead of showing a traceback.
try:
    import libcst as cst
    import libcst.matchers as m
except ImportError:
    click.echo("please install libcst via: pip install libcst")
    sys.exit(1)


# Matches a statement line whose body is a single assignment and whose
# trailing comment is exactly "# local_cache_attribute".
TRANSLATED_COMMENT_MATCHER = m.SimpleStatementLine(
    body=[m.Assign()],
    trailing_whitespace=m.TrailingWhitespace(
        comment=m.Comment(value="# local_cache_attribute")
    ),
)


def attr_to_list(attr: cst.Attribute) -> list[str]:
    """Flatten a dotted attribute chain into the list of its component names.

    For example ``a.b.c`` yields ``["a", "b", "c"]``.

    Raises:
        ValueError: if any link of the chain is something other than a plain
            name or an attribute access on one (e.g. a call or a subscript).
    """
    # innermost link of a chain: `name.member`
    leaf_shape = m.Attribute(value=m.Name(), attr=m.Name())
    # any outer link: `<attribute>.member`
    inner_shape = m.Attribute(value=m.Attribute(), attr=m.Name())

    # walk from the outermost attribute inwards, collecting names back to front
    reversed_parts: list[str] = []
    current = attr
    while not m.matches(current, leaf_shape):
        if not m.matches(current, inner_shape):
            raise ValueError("unexpected attr format")
        reversed_parts.append(current.attr.value)
        current = current.value

    reversed_parts.append(current.attr.value)
    reversed_parts.append(current.value.value)
    return reversed_parts[::-1]


class Unrenamer(cst.CSTTransformer):
    """Convert cached-attribute assignments back into marker comments.

    Every statement of the form ``name = a.b.c  # local_cache_attribute`` is
    replaced by a ``# local_cache_attribute: ...`` comment line, and names
    visited afterwards that equal ``name`` are expanded back to the attribute
    expression.  The replacement table is cleared whenever the traversal
    leaves a function body.
    """

    def __init__(self, module: cst.Module, simplify: bool) -> None:
        # (local name, attribute expression it caches) pairs collected so far
        self.renamers: list[tuple[cst.Name, cst.Attribute]] = []
        # used to render the offending code in error messages
        self.module = module
        # when set, a local called `a_b_c` caching `a.b.c` is written out as if
        # it had the default name `c`
        self.simplify = simplify
        # Name node currently being visited as a keyword-argument name or as
        # the member part of an attribute access.  Such a name is a label, not
        # a reference to a local variable, so it must never be expanded.
        self.protected_name: cst.Name | None = None

    def leave_SimpleStatementLine(
        self, node: cst.SimpleStatementLine, updated: cst.CSTNode
    ) -> cst.CSTNode | None:
        """Replace a marked cache assignment with its comment form.

        Raises:
            Exception: if the marked statement does not assign an attribute
                expression to exactly one plain name.
        """
        if not m.matches(updated, TRANSLATED_COMMENT_MATCHER):
            return updated
        assign: cst.Assign = updated.body[0]

        # Chained (`a = b = x.y`) or non-name targets cannot be expressed in
        # the comment form; refuse them instead of silently dropping targets.
        targets = assign.targets
        if len(targets) != 1 or not isinstance(targets[0].target, cst.Name):
            raise Exception(
                f"Unexpected assignment target: {self.module.code_for_node(assign)}"
            )
        name: cst.Name = targets[0].target

        value_attr = assign.value
        if not isinstance(value_attr, cst.Attribute):
            raise Exception(
                f"Unexpected non-attribute assignment: {self.module.code_for_node(assign)}"
            )
        self.renamers.append((name, value_attr))

        attr_list = attr_to_list(value_attr)
        attr_str = ".".join(attr_list)
        attr_longname = "_".join(attr_list)
        orig_name = name.value

        if self.simplify and orig_name == attr_longname:
            orig_name = attr_list[-1]

        # the local name is only spelled out when it differs from the default,
        # i.e. the last component of the attribute path
        if orig_name != attr_list[-1]:
            comment_str = f"{attr_str} -> {orig_name}"
        else:
            comment_str = attr_str

        return cst.EmptyLine(
            indent=True,
            comment=cst.Comment(f"# local_cache_attribute: {comment_str}"),
        )

    def visit_Arg_keyword(self, node: cst.Arg) -> None:
        # `f(x=...)`: the keyword `x` names a parameter, not the cached local
        self.protected_name = node.keyword

    def leave_Arg_keyword(self, node: cst.Arg) -> None:
        self.protected_name = None

    def visit_Attribute_attr(self, node: cst.Attribute) -> None:
        # `obj.x`: the member `x` is not the cached local either
        self.protected_name = node.attr

    def leave_Attribute_attr(self, node: cst.Attribute) -> None:
        self.protected_name = None

    def leave_Name(self, node: cst.Name, updated: cst.Name) -> cst.CSTNode:
        """Expand a reference to a cached local back to its attribute."""
        if node is self.protected_name:
            return updated
        for old_name, attr in self.renamers:
            if updated.deep_equals(old_name):
                return attr
        return updated

    def leave_FunctionDef_body(self, node: cst.FunctionDef) -> None:
        # cached locals do not outlive the function body being left
        self.renamers.clear()


class Renamer(cst.CSTTransformer):
    """Expand ``# local_cache_attribute: ...`` comments into cached locals.

    A comment line ``# local_cache_attribute: a.b.c`` (optionally followed by
    ``-> local_name``) is replaced by the assignment
    ``c = a.b.c  # local_cache_attribute``, and attribute chains visited
    afterwards that start with ``a.b.c`` are rewritten to use the local.
    The replacement table is cleared whenever the traversal leaves a function
    body.
    """

    def __init__(self, _module: cst.Module, _simplify: bool) -> None:
        # keyword-name node of the call argument being traversed, if any
        self.name_is_keyword = None
        # (attribute path split into components, local name replacing it)
        self.renamers: list[tuple[list[str], cst.Name]] = []

    def leave_EmptyLine(
        self, node: cst.EmptyLine, updated: cst.EmptyLine
    ) -> cst.CSTNode:
        """Turn a marker comment line into the corresponding assignment."""
        marker = "# local_cache_attribute: "

        if node.comment is None:
            return updated
        text = node.comment.value
        if not text.startswith(marker):
            return updated

        # "<path>" or "<path> -> <local name>"; an absent or empty local name
        # falls back to the last component of the path
        path_str, _, alias = text[len(marker) :].partition(" -> ")
        path = path_str.split(".")
        local = cst.Name(alias or path[-1])

        assignment = cst.Assign(
            targets=[cst.AssignTarget(target=local)],
            value=self.process_attribute(path),
        )
        marker_comment = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(value="  "),
            comment=cst.Comment(value="# local_cache_attribute"),
        )
        replacement = cst.SimpleStatementLine(
            body=[assignment],
            trailing_whitespace=marker_comment,
        )

        # register only after the value was built, so that the new local is
        # not used inside its own definition
        self.renamers.append((path, local))
        return replacement

    def visit_Name(self, node: cst.Name) -> None:
        """Reject code that already uses the name of a cached local."""
        if node is self.name_is_keyword:
            # keyword-argument names cannot clash with local variables
            return
        clash = next(
            (local for _, local in self.renamers if node.deep_equals(local)),
            None,
        )
        if clash is not None:
            raise Exception(f"Name {clash.value} already exists in the function")

    def visit_Arg_keyword(self, node: cst.Arg) -> None:
        self.name_is_keyword = node.keyword

    def leave_Arg_keyword(self, node: cst.Arg) -> None:
        self.name_is_keyword = None

    def process_attribute(self, node: list[str]) -> cst.BaseExpression:
        """Build the expression for the path `node`, using cached locals.

        The longest prefix of at least two components that has a cached local
        is replaced by that local; the remaining components are appended as
        attribute accesses.
        """
        assert node

        # look for the longest cached prefix, never shorter than two parts
        for split in range(len(node), 1, -1):
            prefix = node[:split]
            cached = next(
                (local for path, local in self.renamers if path == prefix),
                None,
            )
            if cached is not None:
                expr: cst.BaseExpression = cached
                break
        else:
            # nothing cached: start from the bare first name
            split = 1
            expr = cst.Name(value=node[0])

        for part in node[split:]:
            expr = cst.Attribute(value=expr, attr=cst.Name(value=part))
        return expr

    def visit_Attribute(self, node: cst.Attribute) -> bool:
        # do not descend into attribute chains; the whole chain is handled
        # at once in leave_Attribute
        return False

    def leave_Attribute(
        self, node: cst.Attribute, updated: cst.Attribute
    ) -> cst.CSTNode:
        """Rewrite a plain dotted chain to use cached locals where possible."""
        # children were not visited, so nothing can have changed
        assert node.deep_equals(updated)
        try:
            parts = attr_to_list(updated)
        except ValueError:
            # not a plain chain of names; leave it untouched
            return updated
        return self.process_attribute(parts)

    def leave_FunctionDef_body(self, node: cst.FunctionDef) -> None:
        # cached locals do not outlive the function body being left
        self.renamers.clear()


def transform_file(
    path: Path, transformer: type[cst.CSTTransformer], simplify: bool
) -> None:
    """Apply `transformer` to the file at `path`, rewriting it in place.

    The file is parsed, visited by ``transformer(module, simplify)`` and
    written back only when the generated code differs from the original.
    Any failure is reported via `click.echo` and swallowed, so a caller
    looping over many files can carry on with the next one.

    The file is read and written as UTF-8 explicitly: that is the default
    encoding of Python source, whereas the locale encoding used by
    `Path.read_text` / `Path.write_text` without an argument may differ.
    """
    try:
        module = cst.parse_module(path.read_text(encoding="utf-8"))
        modified = module.visit(transformer(module, simplify))
        if modified.code != module.code:
            path.write_text(modified.code, encoding="utf-8")
            click.echo(f"Successfully converted {path}")
    except Exception as e:
        click.echo(f"Failed to convert {path}: {e}")


@click.command()
@click.argument(
    "filename",
    nargs=-1,
    type=click.Path(exists=True, file_okay=True, dir_okay=True),
)
@click.option("-r", "--reverse", is_flag=True)
@click.option("-s", "--simplify", is_flag=True)
def main(filename: list[str], reverse: bool, simplify: bool) -> None:
    # Command-line entry point: transform every given file; directories are
    # searched recursively for *.py files.
    # By default marker comments are expanded into code (Renamer); with
    # --reverse the code is folded back into comments (Unrenamer).
    if not filename:
        raise click.ClickException("No files specified")

    transformer = Unrenamer if reverse else Renamer

    for name in filename:
        root = Path(name)
        targets = root.glob("**/*.py") if root.is_dir() else [root]
        for target in targets:
            transform_file(target, transformer, simplify)


# Run the command-line interface when the file is executed as a script.
if __name__ == "__main__":
    main()