from multiprocessing import cpu_count
import os  # allow imports
import sys

# Make the project's own packages importable when this file is run as a script:
# parent_dir is the directory one level above the folder containing this file.
# Each listed sub-folder is appended to sys.path, followed by parent_dir itself.
# Entries are appended (not prepended), so same-named stdlib modules
# (e.g. `statistics`) still take precedence.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
for module_folder in ["algorithms", "benchmark", "zx_dfs", "statistics", "bin"]:
    sys.path.append(os.path.join(parent_dir, module_folder))
# Formerly the `else:` branch of the loop above; with no `break` it always ran,
# so it is written here as a plain statement.
sys.path.append(parent_dir)

import argparse
from glob import glob

from typing import Callable

from benchmark.run_benchmark import benchmark_run
from algorithms.pyzx import (
    unoptimized,
    pyzx_full_reduce,
    pyzx_teleport_reduce,
    pyzx_scalar_reduce,
)


from algorithms.idfs import idfs_no_colour_cycle as idfs
from algorithms.dfs import dfs_no_colour_cycle as dfs
from zx_dfs.local_elimination import LocalElimination
from zx_dfs.local_dfs import LocalDFS
from zx_dfs.metric import MetricTcount, MetricEdge


def cli_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the benchmark script.

    Options:
        -f / --fast:   use the small, hard-coded subset of circuits.
        -ff / --files: explicit list of circuit files to benchmark.
        -c / --cores:  worker count (defaults to ``cpu_count()``).
        -t / --time:   time budget forwarded to ``benchmark_run`` as
                       ``max_duration`` (default 60 * 60 * 1.5 = 5400;
                       presumably seconds, i.e. 1.5 h — confirm).

    Returns:
        The configured parser; the caller is responsible for parsing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--fast", action="store_true",
        help="benchmark only a small, fixed subset of circuits",
    )
    parser.add_argument(
        "-ff", "--files", nargs="+", type=str,
        help="circuit files to benchmark (ignored when --fast is given)",
    )
    parser.add_argument(
        "-c", "--cores", type=int, default=cpu_count(),
        help="number of workers (default: %(default)s)",
    )
    # The default is an int so that args.time has the same type whether or
    # not the flag is given (60 * 60 * 1.5 alone evaluates to the float 5400.0).
    parser.add_argument(
        "-t", "--time", type=int, default=int(60 * 60 * 1.5),
        help="maximum benchmark duration (default: %(default)s)",
    )

    return parser


if __name__ == "__main__":
    # Optimisation routines to benchmark; uncomment entries to include them.
    l_func_opt: list[Callable] = [
        dfs,
        # idfs,
        # pyzx_full_reduce,
        # pyzx_teleport_reduce,
        # pyzx_scalar_reduce,
    ]
    parser = cli_parser()
    args = parser.parse_args()

    # Resolve the circuit directory from this file's location rather than the
    # current working directory, so the script works no matter where it is
    # launched from. When launched from its own folder this is the same
    # directory the old "../circuits" path pointed at.
    circuits_dir = os.path.join(parent_dir, "circuits")

    # Circuit selection: --fast takes precedence over --files; with neither,
    # every .qasm file in the circuit directory is used.
    if args.fast:
        fast_circuits = [
            "xor5d1.qasm",
            "graycode6.qasm",
            "ham3tc.qasm",
            "mod5d4.qasm",
            "mod5mils.qasm",
            "3_17tc.qasm",
            "mod5d1.qasm",
            "mod5d2.qasm",
            "hwb4-11-21.qasm",
            "hwb4-11-23.qasm",
            "or5d1.qasm",
            "4_49-12-32.qasm",
            "mspk_4_49_12.qasm",
            "mspk_hwb4_12.qasm",
            "mspk_hwb4_13.qasm",
            "mspk_nth_primes4_12.qasm",
            "mspk_4_49_13.qasm",
            "2of5d2.qasm",
            "gf2^3mult_11_47.qasm",
            "mspk_nth_primes4_13.qasm",
        ]
        l_circuits = [os.path.join(circuits_dir, name) for name in fast_circuits]
    elif args.files:
        # User-supplied paths are taken as given (relative to the CWD).
        l_circuits = args.files
    else:
        # glob() returns files in filesystem-dependent order; sort so that
        # repeated runs process circuits in the same order.
        l_circuits = sorted(glob(os.path.join(circuits_dir, "*.qasm")))

    # An empty selection would silently run a benchmark over nothing.
    if not l_circuits:
        parser.error(f"no circuits found in {circuits_dir}")

    # run benchmark
    benchmark_run(
        l_circuits=l_circuits,
        l_func_opt=l_func_opt,
        latex=False,
        l_metric=[MetricTcount(), MetricEdge()],
        max_duration=args.time,
        workers=args.cores,
    )
