import os  # allow imports
import sys

# Make the project's modules importable when this file is run directly as a
# script. parent_dir is the directory two levels above this file.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Each module folder is added so its modules can be imported by bare name.
for module_folder in ["algorithms", "benchmark", "zx_dfs", "statistics", "bin"]:
    sys.path.append(f"{parent_dir}/{module_folder}")
# parent_dir itself is added so package-qualified imports such as
# `benchmark.run_benchmark` resolve. This always runs, after the loop above.
sys.path.append(f"{parent_dir}")

import argparse
from glob import glob

from typing import Callable

from benchmark.run_benchmark import benchmark_run
from algorithms.pyzx import (
    unoptimized,
    pyzx_full_reduce,
    pyzx_scalar_reduce,
    pyzx_teleport_reduce,
)

from algorithms.dfs import (
    dfs_no_colour_cycle,
    dfs_no_colour_cycle_limit_colour_change,
    dfs_limit_colour_change,
    dfs_no_colour_cycle_max_depth,
    dfs_strict_convergence,
    dfs_no_colour_cycle_kill_non_circuits,
    dfs,
)

from algorithms.local_elimination import (
    local_elimination_no_colour_cycle,
    local_elimination_no_colour_cycle_limit_colour_change,
    local_elimination_limit_colour_change,
    local_elimination_no_colour_cycle_max_depth,
    local_elimination_strict_convergence,
    local_elimination_no_colour_cycle_kill_non_circuits,
    local_elimination,
)

from algorithms.idfs import (
    idfs_no_colour_cycle,
    idfs_no_colour_cycle_limit_colour_change,
    idfs_limit_colour_change,
    idfs_no_colour_cycle_max_depth,
    idfs_strict_convergence,
    idfs_no_colour_cycle_kill_non_circuits,
)

from zx_dfs.metric import MetricTcount, MetricEdge, MetricTwoQubit


def cli_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this benchmark script.

    Returns:
        A parser exposing one boolean option, ``fast`` (``-f``/``--fast``).
        When set, the script benchmarks a small fixed subset of circuits
        instead of every ``.qasm`` file in the circuit directory.
    """
    parser = argparse.ArgumentParser(
        description="Run the circuit-optimisation benchmark over QASM circuits."
    )
    parser.add_argument(
        "-f",
        "--fast",
        action="store_true",
        help="benchmark only a small, fixed subset of circuits instead of "
        "every .qasm file in the circuit directory",
    )

    return parser


if __name__ == "__main__":
    # Optimisation routines to benchmark; uncomment entries to include them.
    l_func_opt: list[Callable] = [
        # local_elimination_no_colour_cycle_limit_colour_change,
        local_elimination_no_colour_cycle,
        # local_elimination_limit_colour_change,
        # local_elimination_no_colour_cycle_max_depth,
        # local_elimination_no_colour_cycle_kill_non_circuits,
    ]
    parser = cli_parser()
    args = parser.parse_args()

    # NOTE: this path is relative to the current working directory, not to this
    # file, so the script must be launched from a directory that sits next to
    # `circuits/` (presumably the folder containing this script).
    circuit_dir = "../circuits"

    if args.fast:
        # Reduced benchmark: a small, fixed subset of circuits.
        fast_circuit_names = [
            "xor5d1",
            "graycode6",
            "ham3tc",
            "mod5d4",
            "mod5mils",
            "3_17tc",
            "mod5d1",
            "mod5d2",
            "hwb4-11-21",
            "hwb4-11-23",
            "or5d1",
            "4_49-12-32",
            "mspk_4_49_12",
            "mspk_hwb4_12",
            "mspk_hwb4_13",
            "mspk_nth_primes4_12",
            "mspk_4_49_13",
            "2of5d2",
            "gf2^3mult_11_47",
            "mspk_nth_primes4_13",
        ]
        l_circuits = [f"{circuit_dir}/{name}.qasm" for name in fast_circuit_names]
    else:
        # Full benchmark: every QASM file in the circuit directory. glob() yields
        # files in arbitrary, filesystem-dependent order; sort for reproducible
        # runs.
        l_circuits = sorted(glob(f"{circuit_dir}/*.qasm"))

    # Fail loudly instead of silently benchmarking nothing when the script is
    # launched from a directory where the relative circuit path does not resolve.
    if not any(os.path.isfile(path) for path in l_circuits):
        parser.error(
            f"no circuit files found under {circuit_dir!r} "
            f"(relative to working directory {os.getcwd()!r})"
        )

    # run benchmark
    benchmark_run(
        l_circuits=l_circuits,
        l_func_opt=l_func_opt,
        latex=False,
        l_metric=[MetricTcount(), MetricEdge()],
        # 6 hours, assuming benchmark_run measures max_duration in seconds
        # (TODO confirm against benchmark_run).
        max_duration=60 * 60 * 6,
    )
