feat(snippet): converter for local_cache_attribute

pull/2633/head
matejcik 2 years ago committed by matejcik
parent 579ee06b6f
commit 1a46c7dfdc

@ -0,0 +1,225 @@
"""
Helper for handling `local_cache_attribute` size optimizations in `core/src` code.
Has the possibility to transfer comments into code changes and vice-versa.
Possible improvements (TODO):
- do not stop with the renaming process when inner function is encountered
def abc(self):
x = self.x # local_cache_attribute
ghi(x)
y = x
def ggg():
...
ghi(self.x)
y = self.x
- do not rename when the new name is already a global symbol
import multisig
def abc(msg):
multisig = msg.multisig # local_cache_attribute
multisig.ask(multisig) # ERROR
- do not rename two caches with the same name
slice_view = aprime.slice_view # local_cache_attribute
slice_view = bprime.slice_view # local_cache_attribute
"""
from __future__ import annotations
import sys
from pathlib import Path
import click
try:
import libcst as cst
import libcst.matchers as m
except ImportError:
click.echo("please install libcst via: pip install libcst")
sys.exit(1)
TRANSLATED_COMMENT_MATCHER = m.SimpleStatementLine(
body=[m.Assign()],
trailing_whitespace=m.TrailingWhitespace(
comment=m.Comment(value="# local_cache_attribute")
),
)
def attr_to_list(attr: cst.Attribute) -> list[str]:
if m.matches(attr, m.Attribute(value=m.Name(), attr=m.Name())):
return [attr.value.value, attr.attr.value]
if m.matches(attr, m.Attribute(value=m.Attribute(), attr=m.Name())):
return attr_to_list(attr.value) + [attr.attr.value]
raise ValueError("unexpected attr format")
class Unrenamer(cst.CSTTransformer):
def __init__(self, module: cst.Module, simplify: bool) -> None:
self.renamers: list[tuple[cst.Name, cst.Attribute]] = []
self.module = module
self.simplify = simplify
def leave_SimpleStatementLine(
self, node: cst.SimpleStatementLine, updated: cst.CSTNode
) -> cst.CSTNode | None:
if not m.matches(updated, TRANSLATED_COMMENT_MATCHER):
return updated
assign: cst.Assign = updated.body[0]
name: cst.Name = assign.targets[0].target
value_attr: cst.Attribute = assign.value
if not isinstance(value_attr, cst.Attribute):
raise Exception(
f"Unexpected non-attribute assignment: {self.module.code_for_node(assign)}"
)
self.renamers.append((name, value_attr))
attr_list = attr_to_list(value_attr)
attr_str = ".".join(attr_list)
attr_longname = "_".join(attr_list)
orig_name = name.value
if self.simplify and orig_name == attr_longname:
orig_name = attr_list[-1]
if orig_name != attr_list[-1]:
comment_str = f"{attr_str} -> {orig_name}"
else:
comment_str = attr_str
return cst.EmptyLine(
indent=True,
comment=cst.Comment(f"# local_cache_attribute: {comment_str}"),
)
def leave_Name(self, node: cst.Name, updated: cst.Name) -> cst.CSTNode:
for old_name, attr in self.renamers:
if updated.deep_equals(old_name):
return attr
return updated
def leave_FunctionDef_body(self, node: cst.FunctionDef) -> None:
self.renamers.clear()
class Renamer(cst.CSTTransformer):
def __init__(self, _module: cst.Module, _simplify: bool) -> None:
self.renamers: list[tuple[list[str], cst.Name]] = []
self.name_is_keyword = None
def leave_EmptyLine(
self, node: cst.EmptyLine, updated: cst.EmptyLine
) -> cst.CSTNode:
if not m.matches(node, m.EmptyLine(comment=m.Comment())):
return updated
comment = node.comment.value
if not comment.startswith("# local_cache_attribute: "):
return updated
value_str = comment[len("# local_cache_attribute: ") :]
if " -> " in value_str:
value_str, target_str = value_str.split(" -> ", maxsplit=1)
else:
target_str = None
attr = value_str.split(".")
name = cst.Name(target_str or attr[-1])
statement = cst.SimpleStatementLine(
body=[
cst.Assign(
targets=[cst.AssignTarget(target=name)],
value=self.process_attribute(attr),
)
],
trailing_whitespace=cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(value=" "),
comment=cst.Comment(value="# local_cache_attribute"),
),
)
self.renamers.append((attr, name))
return statement
def visit_Name(self, node: cst.Name) -> None:
if node is self.name_is_keyword:
return
for _, name in self.renamers:
if node.deep_equals(name):
raise Exception(f"Name {name.value} already exists in the function")
def visit_Arg_keyword(self, node: cst.Arg) -> None:
self.name_is_keyword = node.keyword
def leave_Arg_keyword(self, node: cst.Arg) -> None:
self.name_is_keyword = None
def process_attribute(self, node: list[str]) -> cst.BaseExpression:
assert node
if len(node) == 1:
return cst.Name(value=node[0])
for old_attr, name in self.renamers:
if node == old_attr:
return name
return cst.Attribute(
value=self.process_attribute(node[:-1]), attr=cst.Name(value=node[-1])
)
def visit_Attribute(self, node: cst.Attribute) -> bool:
# prevent recursing into attribute chains so that we can recurse manually
# in leave_attribute
return False
def leave_Attribute(
self, node: cst.Attribute, updated: cst.Attribute
) -> cst.CSTNode:
assert node.deep_equals(updated)
try:
return self.process_attribute(attr_to_list(updated))
except ValueError:
return updated
def leave_FunctionDef_body(self, node: cst.FunctionDef) -> None:
self.renamers.clear()
def transform_file(
path: Path, transformer: type[cst.CSTTransformer], simplify: bool
) -> None:
try:
module = cst.parse_module(path.read_text())
modified = module.visit(transformer(module, simplify))
if modified.code != module.code:
path.write_text(modified.code)
click.echo(f"Successfully converted {path}")
except Exception as e:
click.echo(f"Failed to convert {path}: {e}")
@click.command()
@click.argument(
"filename", nargs=-1, type=click.Path(exists=True, file_okay=True, dir_okay=True)
)
@click.option("-r", "--reverse", is_flag=True)
@click.option("-s", "--simplify", is_flag=True)
def main(filename: list[str], reverse: bool, simplify: bool) -> None:
if not filename:
raise click.ClickException("No files specified")
if reverse:
transformer = Unrenamer
else:
transformer = Renamer
for name in filename:
path = Path(name)
if path.is_dir():
for subpath in path.glob("**/*.py"):
transform_file(subpath, transformer, simplify)
else:
transform_file(path, transformer, simplify)
if __name__ == "__main__":
main()
Loading…
Cancel
Save