#!/usr/bin/env python3

import subprocess
from pathlib import Path

import click

from elftools.construct import Struct, ULInt32, GreedyRange
from elftools.common.construct_utils import ULEB128
from elftools.elf.elffile import ELFFile

# `nm` symbol-type letters to keep. They are compared case-insensitively, so both
# local (lowercase) and global (uppercase) symbols of these kinds are matched.
SYMBOL_TYPES = ("t", "w")  # text, weak

# Repository root: two directories above this script. `absolute()` is applied
# before walking up because on Python < 3.9 the main script's `__file__` may be a
# bare relative name such as "stack.py", whose `.parent.parent` is just "." and
# would resolve to the current working directory instead of the repository root.
ROOT = Path(__file__).absolute().parent.parent.resolve()

FIRMWARE_ELF = ROOT / "core" / "build" / "firmware" / "firmware.elf"
# Opened once at import time and left open for the lifetime of the process;
# load_stack_sizes() reads the `.stack_sizes` section through this handle.
elf = ELFFile(FIRMWARE_ELF.open("rb"))


def load_address_map():
    """Load address map from firmware ELF file using `nm`.

    Runs ``arm-none-eabi-nm --radix=d --demangle`` on FIRMWARE_ELF and keeps
    only symbols whose type letter (compared case-insensitively) is listed in
    SYMBOL_TYPES. Addresses are parsed as decimal integers because nm is
    invoked with ``--radix=d``.

    Returns:
        dict mapping integer address -> demangled symbol name. If several kept
        symbols share an address, the one nm lists last wins.

    Raises:
        subprocess.CalledProcessError: if nm exits with a non-zero status.
    """
    out = subprocess.check_output(
        args=["arm-none-eabi-nm", "--radix=d", "--demangle", FIRMWARE_ELF]
    )
    address_map = {}
    for raw_line in out.splitlines():
        # maxsplit=2 keeps demangled names that contain spaces in one piece.
        fields = raw_line.decode().split(maxsplit=2)
        # Symbols without a value (undefined "U", weak-undefined "w", ...) are
        # printed with a blank address column, leaving only two fields. They
        # have no address to map, so skip them instead of failing to unpack.
        if len(fields) != 3:
            continue
        addr, sym_type, name = fields
        if sym_type.lower() in SYMBOL_TYPES:
            address_map[int(addr)] = name
    return address_map


def load_stack_sizes():
    """Load Rust stack sizes from firmware ELF section generated by `-Z emit-stack-sizes`.

    The `.stack_sizes` section of the module-level `elf` is parsed as a
    back-to-back sequence of entries. Each entry is a little-endian 32-bit
    function address (``symbol_addr``) followed by a ULEB128-encoded stack
    size (``stack_size``).

    Returns:
        The parsed entries, in section order; each exposes ``symbol_addr`` and
        ``stack_size`` attributes.

    Raises:
        click.ClickException: if the firmware ELF has no `.stack_sizes` section,
            e.g. because it was built without `-Z emit-stack-sizes`.
    """
    stack_sizes = elf.get_section_by_name(".stack_sizes")
    # get_section_by_name() returns None for a missing section; report that
    # clearly instead of failing later with an AttributeError on `.data()`.
    if stack_sizes is None:
        raise click.ClickException(
            f"{FIRMWARE_ELF} has no .stack_sizes section; "
            "build the firmware with `-Z emit-stack-sizes`"
        )

    Entries = GreedyRange(
        Struct(
            "Entry",
            ULInt32("symbol_addr"),
            ULEB128("stack_size"),
        )
    )
    return Entries.parse(stack_sizes.data())


@click.command()
def main():
    """Print Rust functions' stack size.

    See https://blog.japaric.io/stack-analysis/ for more details.
    """
    # Output: one "<stack size>\t<symbol name>" line per `.stack_sizes` entry,
    # in section order.
    address_map = load_address_map()
    for entry in load_stack_sizes():
        addr = entry.symbol_addr
        symbol_name = address_map.get(addr)
        if symbol_name is None:
            # Retry with bit 0 cleared. NOTE(review): presumably on ARM a Thumb
            # function's address can carry the Thumb bit in the relocated
            # `.stack_sizes` entry but not in nm's symbol value -- confirm.
            symbol_name = address_map.get(addr & ~1)
        if symbol_name is None:
            # The address has no kept nm symbol (e.g. its nm type is not in
            # SYMBOL_TYPES). Report the raw address rather than aborting the
            # whole listing with a KeyError.
            symbol_name = f"<unknown @ {addr:#010x}>"
        print(f"{entry.stack_size}\t{symbol_name}")


if __name__ == "__main__":
    main()