"""
Benchmark driver: run circuit-optimization functions on QASM circuits and collect per-metric statistics.
"""

import os  # allow imports
import sys

from zx_dfs.dfs import AllResults

# Make the project's sub-packages (and the project root itself) importable when
# this file is run as a script. Entries are appended, so they are searched after
# any paths already on sys.path; imports placed above this point cannot rely on them.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
for module_folder in ["algorithms", "benchmark", "zx_dfs", "statistics", "bin"]:
    sys.path.append(f"{parent_dir}/{module_folder}")
# Previously the `else:` clause of the loop above; with no `break` it always ran,
# so it is written as a plain statement.
sys.path.append(parent_dir)

import datetime as DT
import multiprocessing as mp
import subprocess
from dataclasses import dataclass, field, fields
from glob import glob
from multiprocessing import pool, cpu_count
from typing import Callable

import dill
import pandas as pd
import qiskit as qk
import tqdm
from pandas import DataFrame
from qiskit import QuantumCircuit
from tqdm.contrib.concurrent import process_map

from zx_dfs.metric import Metric, MetricTcount
from algorithms.pyzx import (
    unoptimized,
    pyzx_full_reduce,
    pyzx_teleport_reduce,
    pyzx_scalar_reduce,
)
from algorithms.dfs import dfs
from zx_dfs.dfs import AllResults

from .statistics import get_circuit_name, get_circuit_statistics


@dataclass
class LogBenchmark:
    qc: str
    func_opt: str
    start: DT.datetime = DT.datetime.now()
    stop: DT.datetime | None = None
    expired: DT.timedelta | None = None
    result: AllResults | None = None

    def done(self) -> None:
        self.stop = DT.datetime.now()
        self.expired = self.stop - self.start

    def __str__(self) -> str:
        str_return: str = (
            f"===\nQC:\t\t{self.qc}\nFunc_Opt:\t{self.func_opt}\nStart:\t\t{self.start}\nStop:\t\t{self.stop}\nElapsed:\t{self.expired}\n==="
        )
        return str_return


"""
Run an arbitrary benchmark (optimization) function on a single quantum circuit.
"""


# Optional statistics copied from the run-level result object and from each
# per-metric result, in this order, when the attribute exists (e.g. only dfs
# reports node/leaf counts).
_RESULT_ATTRS: tuple[str, ...] = ("n_nodes", "n_leafs", "max_depth")
_METRIC_ATTRS: tuple[str, ...] = ("node", "depth", "expired_time")


def _copy_present_attrs(src: object, names: tuple[str, ...], dst: dict) -> None:
    """Copy each attribute in ``names`` that ``src`` has into ``dst`` under the same key."""
    for name in names:
        if hasattr(src, name):
            dst[name] = getattr(src, name)


def benchmark_one(
    f_qasm: str,
    flush_write: bool,
    fname: str,
    func_opt: Callable,
    l_metric: list | None = None,
    max_duration: int = 60 * 60 * 6,
) -> list | None:
    """Run one optimization function on one QASM circuit and collect statistics.

    Args:
        f_qasm: Path of the QASM file to load.
        flush_write: If True, append this run's rows to ``_{fname}_checkpoint.csv``.
        fname: Base name used for the checkpoint CSV.
        func_opt: Optimization function, called as
            ``func_opt(qc, circ_name=..., metric=..., func_opt=..., max_duration=...)``.
            Its result must have one attribute per metric name.
        l_metric: Metrics to report; each element's ``__str__()`` must name an
            attribute of ``func_opt``'s result. Defaults to ``[MetricTcount()]``.
        max_duration: Forwarded unchanged to ``func_opt`` (presumably seconds).

    Returns:
        One statistics dict per metric (with the wall-clock time under ``"t"``),
        or None if the QASM file does not exist.
    """
    if l_metric is None:
        l_metric = [MetricTcount()]

    circ_name: str = get_circuit_name(f_qasm)

    try:
        qc: QuantumCircuit = qk.QuantumCircuit.from_qasm_file(f_qasm)
    except FileNotFoundError:
        print(f"{f_qasm} not found", file=sys.stderr)
        return None

    # Pass the start time explicitly so the measured time covers exactly this
    # run, independent of how LogBenchmark's default is defined.
    log: LogBenchmark = LogBenchmark(
        circ_name, func_opt.__name__, start=DT.datetime.now()
    )

    # run benchmark
    all_results = func_opt(
        qc,
        circ_name=circ_name,
        metric=l_metric,
        func_opt=func_opt.__name__,
        max_duration=max_duration,
    )

    log.done()
    elapsed_s: float = log.expired.total_seconds()

    l_dic_opt: list[dict] = []

    for m in l_metric:
        metric_name = m.__str__()
        metric_result = getattr(all_results, metric_name)

        dic_opt: dict = get_circuit_statistics(
            qc=metric_result.qc,
            c=metric_result.circuit,
            g=metric_result.graph,
            circuit_name=circ_name,
            func_opt=func_opt,
            metric=metric_name,
        )

        dic_opt["t"] = elapsed_s
        _copy_present_attrs(all_results, _RESULT_ATTRS, dic_opt)
        _copy_present_attrs(metric_result, _METRIC_ATTRS, dic_opt)
        # if hasattr(all_results, "optimal"):
        #    if dic_opt[m.__str__()] == 0:
        #        dic_opt["optimal"] = True
        #    else:
        #        dic_opt["optimal"] = all_results.optimal

        l_dic_opt.append(dic_opt)

    if flush_write and l_dic_opt:
        # Append mode with the default header=True writes a header row per call.
        df_results: DataFrame = pd.DataFrame(l_dic_opt)
        df_results.to_csv(f"_{fname}_checkpoint.csv", index=False, mode="a")

    return l_dic_opt


"""
Create a list of argument tuples, one for every combination of circuit and benchmark function.
"""


def schedule_benchmark_run(
    l_circuits: list[str],
    flush: bool,
    fname: str,
    l_func_opt: list[Callable],
    l_metric: list[str],
    max_duration: int,
) -> list[tuple]:
    """Build one ``benchmark_one`` argument tuple per (circuit, function) pair.

    Each tuple is ``(circuit, flush, fname, func_opt, l_metric, max_duration)``.
    Circuits form the outer loop: all functions for the first circuit come first.
    The same ``l_metric`` object is shared by every tuple.
    """
    return [
        (circuit, flush, fname, opt_fn, l_metric, max_duration)
        for circuit in l_circuits
        for opt_fn in l_func_opt
    ]


"""
Run the benchmark for a given list of quantum circuits and benchmark functions and return the results.
"""


def benchmark_run(
    fname: str = "results",
    l_circuits: list[str] | None = None,
    latex: bool = False,
    l_func_opt: list[Callable] = [
        unoptimized,
        pyzx_full_reduce,
        pyzx_teleport_reduce,
        pyzx_scalar_reduce,
    ],
    l_metric: list[str] = [MetricTcount],
    max_duration: int = 60 * 60 * 6,
    workers: int = cpu_count(),
) -> DataFrame:
    if l_circuits is None:
        l_circuits: list[str] = glob(
            "../circuits/*.qasm"
        )  # circuits used by quarl that pyzx could convert

    l_runs = schedule_benchmark_run(
        l_circuits, True, fname, l_func_opt, l_metric, max_duration
    )

    # generate list of individual values used to freed process map
    s_circ_name, s_flush, s_fname, s_func_opt, s_metric, s_max_duration = zip(*l_runs)

    # run benchmark
    l_results: list = process_map(
        benchmark_one,
        s_circ_name,
        s_flush,
        s_fname,
        s_func_opt,
        s_metric,
        s_max_duration,
        max_workers=workers,
        leave=True,
        position=0,
    )

    # with mp.Pool(processes=mp.cpu_count() + 4) as P:
    #    l_results: list = P.starmap(
    #        benchmark_one,
    #        tqdm.tqdm(l_runs, total=len(l_runs)),
    #    )

    # flatten list of circuits and individual circuits into one list
    l_results_flat: list = [
        v for item in l_results for v in (item if isinstance(item, list) else [item])
    ]

    l_results_flat[:] = [v for v in l_results_flat if v is not None]

    df_results: DataFrame = pd.DataFrame(
        l_results_flat
    ).drop_duplicates()  # drop duplicate circuits
    df_results.to_csv(fname + ".csv", index=False)  # write csv

    # write out results to latex table
    if latex:
        # sort dataframe by circuit name and optimization function
        df_results.sort_values(by=["circuit_name", "func_opt"]).to_latex(
            fname + ".tex", index=False, escape="\\", float_format="{:,.0f}".format
        )

        # geneate pdf
        subprocess.run(["pdflatex", "template.tex"], stdout=subprocess.DEVNULL)

    return df_results
