#! /usr/bin/env python3
#
# Copyright (c) 2020-now by the Zeek Project. See LICENSE for details.

# Graphviz shapes: https://graphviz.gitlab.io/doc/info/shapes.html
# Graphviz attributes: https://www.graphviz.org/doc/info/attrs.html

import os
import os.path
import sys

import diagrams
from diagrams import Diagram, Edge


def tt(text, pt="14"):  # font size 14 is default
    """Wrap *text* in a Courier `<font>` tag with point size *pt*.

    A falsy *text* (empty string, `None`) yields the empty string.
    """
    if not text:
        return ""

    return f'<font face="Courier" point-size="{pt}">{text}</font>'


def tt_small(text, pt="8"):
    """Wrap *text* in a Courier `<font>` tag, using the small point size 8 by default.

    A falsy *text* (empty string, `None`) yields the empty string.
    """
    if text:
        return f'<font face="Courier" point-size="{pt}">{text}</font>'

    return ""


def i(text):
    """Wrap *text* in an `<i>` tag; a falsy *text* yields the empty string."""
    if not text:
        return ""

    return f"<i>{text}</i>"


def i_small(text, pt="8"):
    """Wrap *text* in an italic Helvetica `<font>` tag with point size *pt*.

    A falsy *text* (empty string, `None`) yields the empty string.
    """
    if text:
        return f'<font face="Helvetica" point-size="{pt}"><i>{text}</i></font>'

    return ""


def bold(text):
    """Wrap *text* in a `<b>` tag; a falsy *text* yields the empty string."""
    if not text:
        return ""

    return f"<b>{text}</b>"


def tt_bold(text, pt="14"):  # font size 14 is default
    """Wrap *text* in a bold Courier `<font>` tag with point size *pt*.

    A falsy *text* (empty string, `None`) yields the empty string.
    """
    if not text:
        return ""

    return f'<font face="Courier" point-size="{pt}"><b>{text}</b></font>'


def multiline(*lines):
    """Join *lines* with `<br/>`, padding each one with a single space on both sides.

    Calling it without arguments returns the empty string.
    """
    padded = (f" {line} " for line in lines)
    return "<br/>".join(padded)


def mergeAttrs(attrs, defaults):
    """Fill in entries missing from *attrs* with the values in *defaults*.

    *attrs* is modified in place and also returned, so callers may use either
    the side effect or the return value. Keys already present in *attrs* keep
    their value: an explicitly supplied attribute takes precedence over its
    default. (Previously `attrs.update(defaults)` let defaults silently
    overwrite explicit attributes.)
    """
    for key, value in defaults.items():
        # `setdefault` only inserts when the key is absent, leaving explicit
        # settings untouched.
        attrs.setdefault(key, value)

    return attrs


class Cluster(diagrams.Cluster):
    """Wrapper around `diagrams.Cluster` that records `input` and `output` nodes for later use."""

    def __init__(self, label=None, **kwargs):
        """Create the cluster.

        `label`, if non-empty, is wrapped in angle brackets before being passed
        on; an empty/missing label becomes "". All keyword arguments are merged
        with the defaults below and handed to `diagrams.Cluster` as its
        `graph_attr` dictionary.
        """
        default_attrs = {
            "labeljust": "c",
        }

        mergeAttrs(kwargs, default_attrs)
        diagrams.Cluster.__init__(
            self, label=f"<{label}>" if label else "", graph_attr=kwargs
        )

    def __exit__(self, *args):
        """Leave the cluster context and snapshot the current `input`/`output`.

        `input` and `output` are not locals here; they are looked up at call
        time in this module's globals, which the `with` bodies of the script
        rebind (via `:=`) to the cluster's entry and exit node(s). Whatever
        they refer to when the `with` block ends is stored on the instance.
        """
        super().__exit__(*args)

        # NOTE: `input` is also a Python builtin, so this lookup falls back to
        # the builtin function when the module has not bound its own `input`;
        # the `NameError` branch therefore does not trigger for it in practice.
        try:
            self.input = input
        except NameError:
            pass

        # `output` has no builtin fallback; the guard covers use of this class
        # before the module has bound a global `output`.
        try:
            self.output = output
        except NameError:
            pass


class Node(diagrams.Node):
    """Common base for every node in this diagram; applies shared default attributes."""

    def __init__(self, label=None, xlabel=None, **kwargs):
        """Create a node from optional `label`/`xlabel` markup plus extra attributes.

        Non-empty `label` and `xlabel` values are wrapped in angle brackets and
        passed on as the attributes of the same name. Remaining keyword
        arguments are merged with the defaults below via `mergeAttrs` and
        forwarded to `diagrams.Node`.
        """
        node_defaults = dict(
            shape="rectangle",
            width="1",
            height="0.7",
            labelloc="c",
            fixedsize="false",
        )

        # Empty or missing labels are simply left out of the attribute set.
        for attr_name, markup in (("label", label), ("xlabel", xlabel)):
            if markup:
                node_defaults[attr_name] = f"<{markup}>"

        mergeAttrs(kwargs, node_defaults)
        super().__init__(**kwargs)


class CodeArea(Cluster):
    """Cluster whose title is rendered with `tt_bold` and whose `bgcolor` is "none"."""

    def __init__(self, label):
        title = tt_bold(label)
        super().__init__(title, bgcolor="none")


def coloredNode(color):
    """Return a new `Node` subclass whose instances are filled with *color*."""

    class ColoredNode(Node):
        """Filled node; the fill color is kept in the class attribute `_color`."""

        _color = color

        def __init__(self, label, xlabel=""):
            # The main label is typeset through `tt`; `xlabel` is passed on as given.
            super().__init__(
                label=tt(label),
                xlabel=xlabel,
                style="filled",
                fillcolor=ColoredNode._color,
            )

    return ColoredNode


# Node flavors used by the diagram below; each is a `Node` subclass with its
# own fill color.
Component = coloredNode("lightgreen")  # drivers, code generators, JIT, linker steps
InputOutput = coloredNode("lightblue")  # source files, ASTs, C++ units, final result
Library = coloredNode("orange")  # runtime libraries
Pass = coloredNode("lightyellow")  # individual AST passes

# Main

# Exactly one command-line argument is expected: the base name (without
# extension) of the files to generate.
if len(sys.argv) == 2:
    output = sys.argv[1]
else:
    script_name = os.path.basename(sys.argv[0])
    print(f"Usage: {script_name} <output-filename-without-extension>")
    sys.exit(1)

# Build the diagram. `filename=output` uses the command-line argument as the
# base name; `outformat` requests PDF, DOT and SVG.
#
# Convention used throughout: inside each cluster's `with` body the
# module-level names `input` and `output` are (re)bound with `:=` to the
# cluster's entry and exit node(s). `Cluster.__exit__` then stores them as
# `<cluster>.input` / `<cluster>.output`, which is how clusters get wired
# together further down. Statement order therefore matters.
with Diagram(
    "",
    curvestyle="ortho",
    show=False,
    direction="TB",
    filename=output,
    graph_attr={"forcelabels": "true", "fontname": "Helvetica"},
    outformat=["pdf", "dot", "svg"],
):
    with CodeArea("hilti::Driver::compileUnits()") as process_ast:
        # Spicy pipeline

        with Cluster(i("Process all Spicy modules inside AST")) as process_ast_spicy:
            validate_pre = Pass(
                "Validate AST",
                xlabel=multiline(
                    tt_small("spicy::detail::validator::validatePre()"),
                    tt_small("hilti::detail::validator::validatePre()"),
                ),
            )

            with Cluster("Resolve AST") as resolve_ast_spicy:
                build_scopes = Pass(
                    "Build scopes",
                    xlabel=multiline(
                        tt_small("spicy::detail::scope_builder::build()"),
                        tt_small("hilti::detail::scope_builder::build()"),
                    ),
                )

                unify_types = Pass(
                    "Unify Types",
                    xlabel=multiline(
                        tt_small("spicy::type_unifier::detail::unifyType()"),
                        tt_small("hilti::type_unifier::detail::unifyType()"),
                    ),
                )

                resolve = Pass(
                    "Resolve Nodes",
                    xlabel=multiline(
                        tt_small("spicy::detail::resolver::resolve()"),
                        tt_small("hilti::detail::resolver::resolve()"),
                    ),
                )

                # Chain the three passes; first is the entry, last is the exit.
                (input := build_scopes) >> unify_types >> (output := resolve)

                # Back edge from the last pass to the first, labelled as a loop.
                (output >> Edge(label="Iterate\\nuntil\\nstable") >> input)

            validate_post = Pass(
                "Validate AST",
                xlabel=multiline(
                    tt_small("spicy::detail::validator::validatePost()"),
                    tt_small("hilti::detail::validator::validatePost()"),
                ),
            )

            with CodeArea("Spicy CodeGen") as spicy_codegen:
                (
                    input := Pass(
                        multiline(
                            "Transform AST",
                            i_small("Translates Spicy code into HILTI code"),
                        ),
                        tt_small("spicy::detail::CodeGen::compileAST()"),
                    )
                )
                # Single node: it is both entry and exit of this cluster.
                (output := input)

            # Wire the Spicy stages: pre-validation -> resolver loop ->
            # post-validation -> code generation.
            (input := validate_pre) >> resolve_ast_spicy.input
            (
                resolve_ast_spicy.output
                >> validate_post
                >> (output := spicy_codegen.output)
            )

        # NOTE(review): these two bindings are rebound inside the HILTI cluster
        # below and again after it, before `process_ast` exits; they look
        # redundant -- confirm before removing.
        (input := process_ast_spicy.input)
        (output := process_ast_spicy.output)

        # HILTI pipeline

        with Cluster(i("Process all HILTI modules inside AST")) as process_ast_hilti:
            validate_pre = Pass(
                "Validate AST",
                xlabel=tt_small("hilti::detail::validator::validatePre()"),
            )

            # NOTE(review): the name `resolve_ast_spicy` is reused here for the
            # HILTI resolver cluster. It rebinds the Spicy one, which has
            # already been wired up above, so the diagram is unaffected.
            with Cluster("Resolve AST") as resolve_ast_spicy:
                build_scopes = Pass(
                    "Build scopes",
                    xlabel=tt_small("hilti::detail::scope_builder::build()"),
                )

                unify_types = Pass(
                    "Unify Types",
                    xlabel=tt_small("hilti::type_unifier::detail::unifyType()"),
                )

                resolve = Pass(
                    "Resolve Nodes",
                    xlabel=tt_small("hilti::detail::resolver::resolve()"),
                )

                # Same chain and back edge as in the Spicy resolver cluster.
                (input := build_scopes) >> unify_types >> (output := resolve)
                output >> Edge(label="Iterate\\nuntil\\nstable") >> input

            validate_post = Pass(
                "Validate AST",
                xlabel=tt_small("hilti::detail::validator::validatePost()"),
            )

            # Second post-validation node, placed after the optimizer below.
            validate_post_opt = Pass(
                "Validate AST",
                xlabel=tt_small("hilti::detail::validator::validatePost()"),
            )

            (input := validate_pre) >> resolve_ast_spicy.input

            optimize = Pass(
                "Optimize AST", xlabel=tt_small("hilti::detail::optimizer::optimize()")
            )

            # Resolver loop -> validation -> optimization -> validation again.
            (
                resolve_ast_spicy.output
                >> validate_post
                >> optimize
                >> (output := validate_post_opt)
            )

        # Connect the Spicy pipeline to the HILTI pipeline through an
        # intermediate "Transformed AST" node.
        (
            process_ast_spicy.output
            >> InputOutput(
                multiline(
                    "Transformed AST", i_small("Pure HILTI AST, no more Spicy code")
                )
            )
            >> process_ast_hilti.input
        )

        # `process_ast` as a whole starts at the Spicy pipeline's entry and
        # ends at the HILTI pipeline's exit.
        (input := process_ast_spicy.input)
        (output := process_ast_hilti.output)

    # This cluster binds no `input`/`output` of its own; its two parser nodes
    # are referenced directly by name below.
    with CodeArea("{hilti,spicy}/include/compiler/detail/parser/*") as include_parser:
        parser_hilti = Component("hilti::detail::parser::Driver")
        parser_spicy = Component("spicy::detail::parser::Driver")

    with CodeArea("HILTI CodeGen") as hilti_codegen:
        (input := Component("hilti::detail::CodeGen"))
        (output := input)

    with CodeArea("hilti/include/compiler/jit.h") as hilti_jit:
        (
            input := Component(
                multiline("hilti::JIT", i_small("Spawns clang/GCC to compile C++ code"))
            )
        )
        (output := input)

    with CodeArea("C++ Code") as cxx_code:
        out1 = InputOutput("cxx::Unit<sub>1</sub>")
        outn = InputOutput("cxx::Unit<sub>n</sub>")
        # Transparent, non-constraining edge carrying only a "..." label --
        # presumably to suggest the units between the first and the n-th.
        out1 - Edge(label="...", constraint="false", color="transparent") - outn
        # Entry and exit are the list of both unit nodes, so edges to or from
        # this cluster attach to each of them.
        (input := [out1, outn])
        (output := input)

    with CodeArea("hilti/{include,src}/rt/*") as hilti_runtime:
        (input := Library("HILTI Runtime Library"))
        (output := input)

    with CodeArea("spicy/{include,src}/rt/*") as spicy_runtime:
        (input := Library("Spicy Runtime Library"))
        (output := input)

    with CodeArea("Compiled C++ Code") as compiled_cxx_code:
        out1 = InputOutput("cxx::CompiledUnit<sub>1</sub>")
        outn = InputOutput("cxx::CompiledUnit<sub>n+1</sub>")
        # Same "..." placeholder edge as in the "C++ Code" cluster.
        out1 - Edge(label="...", constraint="false", color="transparent") - outn
        (input := [out1, outn])
        (output := input)

    # Top-level diagram

    # Source files feed their respective parsers.
    InputOutput(tt("*.hlt")) >> parser_hilti
    InputOutput(tt("*.spicy")) >> parser_spicy

    # Both parsers produce the original AST that enters AST processing.
    [parser_hilti, parser_spicy] >> InputOutput("Original AST") >> process_ast.input

    process_ast.output >> InputOutput("Final AST") >> hilti_codegen.input

    # Side path: all C++ units go through `hilti::Unit::link()`, whose linker
    # unit is fed to the JIT as well.
    (
        cxx_code.output
        >> Component("hilti::Unit::link()")
        >> InputOutput("cxx::Unit<sub>linker</sub>")
        >> hilti_jit.input
    )

    # Main path: code generation -> C++ units -> JIT -> compiled units ->
    # system linker -> final executable code.
    (
        hilti_codegen.output
        >> cxx_code.output
        >> hilti_jit.output
        >> compiled_cxx_code.output
        >> (linker := Component("System Linker"))
        >> InputOutput("Final executable code")
    )

    # Both runtime libraries are additional inputs to the system linker.
    hilti_runtime.output >> linker
    spicy_runtime.output >> linker
