# libraries (arrow glyph reference: https://en.wikipedia.org/wiki/⇈)
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import scipy as sp
import scienceplots
import re
from datetime import datetime, timedelta
from itertools import chain


# types
from pandas import DataFrame


# code
class Plotter:
    def __init__(self, path_to_data: str | list[str]) -> None:
        # data import
        if type(path_to_data) is str:
            self.path_to_data: list[str] = [path_to_data]
            self.data: DataFrame = self._import_csv(path_to_data)

            # commonly used variables
            self.algorithms: list[str] = self._get_algorithms(self.data)
            self.circuits: list[str] = self._get_circuits(self.data)

        elif type(path_to_data) is list:
            self.path_to_data: list[str] = path_to_data

            self.algorithms: list[str] = []
            self.circuits: list[str] = []

            for f in self.path_to_data:
                data: DataFrame = self._import_csv(f)

                # commonly used variables
                self.algorithms.append(self._get_algorithms(data))
                self.circuits.append(self._get_circuits(data))

            self.algorithms = list(set(chain.from_iterable(self.algorithms)))
            self.circuits = list(set(chain.from_iterable(self.circuits)))

        # mpl specific
        self._setup_environment()
        pass

    def _setup_environment(self) -> None:
        mpl.use("TkAgg")  # interactive backend
        plt.style.use(["science", "high-vis", "no-latex"])
        plt.rcParams["figure.figsize"] = (11.69, 8.27)

    def _import_csv(
        self, fname: str, b_check_point: bool = False, header_skip: int = 4
    ) -> DataFrame:
        # import final result csv; only has one header in the beginning of the file
        if not b_check_point:
            df_raw: DataFrame = pd.read_csv(fname)
        else:
            # read csv with python file reader
            with open(fname, "r") as f:
                l_lines: list[str] = [
                    line.strip() for line in f.readlines()
                ]  # read lines to list and get rid of white space
                num_lines: int = len(l_lines)  # number of lines

            l_exps: list[dict] = []  # list of all dictionarys for each data line

            # we need to split files every n lines at the header, each junk of data have different amount of fields;
            # chunks of data between headers, identified by line numbers
            for i in range(0, num_lines, header_skip):
                header: list[str] = l_lines[i].split(",")
                data: list[list[str]] = [
                    l_lines[j].split(",") for j in range(i + 1, i + header_skip)
                ]

                # create a dictionary with header and correspondence value for each data line
                for entry in data:
                    dic_tmp: dict[str, str] = {}
                    for j, header_entry in enumerate(header):
                        dic_tmp[header_entry] = entry[j]

                    l_exps.append(dic_tmp)

            df_raw: DataFrame = pd.DataFrame(l_exps)

        # cast dataframe
        df_raw = self._cast_dataframe(df_raw)

        return df_raw

    # cast dataframe to use correct datatype for individual columns
    def _cast_dataframe(self, df: DataFrame) -> DataFrame:
        # possible casts
        l_timestamps = ["timestamp"]
        l_strings = ["circuit_name", "func_opt", "metric"]
        l_bool = ["optimal"]

        # get generate nans if empty string is encountered
        df = df.replace(r"^\s*$", np.nan, regex=True)

        # timestamps
        for ts in l_timestamps:
            if ts in df:
                df[ts] = pd.to_datetime(df[ts])

        # strings
        for s in l_strings:
            if s in df:
                df[s] = df[s].astype(dtype=str)

        # bool
        for b in l_bool:
            if b in df:
                df[b] = df[b].astype(dtype=bool)

        # numeric
        l_numeric = [
            i for i in list(df) if i not in l_timestamps + l_strings + l_bool
        ]  # all non-explicit columns are of a numeric data type
        for n in l_numeric:
            if n in df:
                df[n] = pd.to_numeric(df[n])

        return df

    def _get_circuits(self, df: DataFrame) -> list[str]:
        l_circuits: list[str] = df["circuit_name"].drop_duplicates().tolist()
        return l_circuits

    def _get_algorithms(self, df: DataFrame) -> list[str]:
        l_regex_patterns: list[str] = [
            # r"^pyzx*",
            # r"^unoptimized",
            r"^idfs",
            r"^dfs",
        ]
        l_algorithms: list[str] = df["func_opt"].drop_duplicates().tolist()

        l_final_algorithms: list[str] = []

        for exps in l_regex_patterns:
            p = re.compile(exps)
            for algorithm in l_algorithms:
                m = p.findall(algorithm)
                if m:
                    l_final_algorithms.append(m[-1])

        return list(set(l_final_algorithms))

    def _get_pruning_conditions(self) -> list[str]:
        pass


class PlotterPruningConditions(Plotter):
    """Report the mean node/leaf ratio for each optimisation function."""

    def _select_data(self, df: DataFrame, pattern: str = r"^idfs") -> DataFrame:
        """Return the unique node/leaf rows whose ``func_opt`` matches ``pattern``.

        Only the columns ``circuit_name``, ``func_opt``, ``n_nodes`` and
        ``n_leafs`` are kept; duplicated rows are dropped.
        """
        is_match = df["func_opt"].str.contains(pattern, regex=True)
        wanted_columns = ["circuit_name", "func_opt", "n_nodes", "n_leafs"]
        return df.loc[is_match][wanted_columns].drop_duplicates()

    def _node_leaf_ratio(self, df: DataFrame) -> DataFrame:
        """Add the column ``node-leaf-ratio`` (``n_nodes / n_leafs``) to ``df``.

        The frame is modified in place and also returned.
        """
        ratio = df["n_nodes"] / df["n_leafs"]
        df["node-leaf-ratio"] = ratio
        return df

    def plot(self):
        """Print, for every selected ``func_opt``, its mean node/leaf ratio."""
        df_ratio: DataFrame = self._node_leaf_ratio(self._select_data(self.data))

        func_opts = df_ratio["func_opt"]
        for name in func_opts.drop_duplicates():
            ratios = df_ratio.loc[func_opts == name, "node-leaf-ratio"]
            print(name, np.mean(ratios))


class PlotterTimeConvergence(Plotter):
    """Plot the T-gate count over the elapsed runtime, one subplot per circuit."""

    # convert timestamps to expired seconds since the initial entry
    def _timestamp_to_expired(self, df: DataFrame) -> DataFrame:
        """Add the column ``time-expired`` to ``df`` and return ``df``.

        The column holds the seconds between each ``timestamp`` and the
        earliest ``timestamp`` of the frame. ``df`` is modified in place.
        """
        t_init: datetime = df["timestamp"].min()
        df["time-expired"] = (df["timestamp"] - t_init).dt.total_seconds()
        return df

    def plot(self) -> None:
        """Show ``c_tcount`` against ``time-expired`` for every circuit.

        Each circuit in ``self.circuits`` gets its own subplot (shared x
        axis) with one line per ``func_opt`` value; only rows whose
        ``metric`` is ``c_tcount`` are drawn.
        """
        data: DataFrame = self.data
        if "time-expired" not in data:
            # the imported files only carry timestamps; derive the elapsed
            # time on a copy so that self.data stays untouched
            data = self._timestamp_to_expired(data.copy())

        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        l_conditions: list[str] = data["func_opt"].drop_duplicates().tolist()

        # wrap around the colour cycle if there are more conditions than colours
        dic_colors = {
            condition: colors[idx % len(colors)]
            for idx, condition in enumerate(l_conditions)
        }

        # squeeze=False always yields a 2-D array of axes, so a single
        # circuit is handled exactly like several circuits
        fig, axes = plt.subplots(len(self.circuits), 1, sharex=True, squeeze=False)
        for ax, circuit in zip(axes[:, 0], self.circuits):
            is_circuit = (data["metric"] == "c_tcount") & (
                data["circuit_name"] == circuit
            )

            for condition in l_conditions:
                df_metric: DataFrame = data.loc[
                    is_circuit & (data["func_opt"] == condition),
                    ["c_tcount", "time-expired"],
                ]
                ax.plot(
                    df_metric["time-expired"],
                    df_metric["c_tcount"],
                    label=condition.replace("_", " "),
                    color=dic_colors[condition],
                )

            ax.set_title(circuit.replace("_", " "))
            ax.set_xlabel(r"Runtime $t_{\text{run}}$ in $[s]$")
            ax.set_ylabel(r"T-gate count $N_{\text{T-gate}}$")

        # one figure-wide legend, built from the current axes; the dict
        # removes repeated labels
        handles, labels = fig.gca().get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        fig.legend(by_label.values(), by_label.keys())
        plt.show()

    def generate_table(self) -> None:
        """Re-import all files with their elapsed time and combine them.

        TODO: the combined frame is not used yet; no table is produced.
        """
        l_frames: list[DataFrame] = [
            self._timestamp_to_expired(self._import_csv(f))
            for f in self.path_to_data
        ]
        # a single concat avoids re-copying the growing frame for every file
        df = pd.concat(l_frames) if l_frames else pd.DataFrame()  # noqa: F841


class TableResults(Plotter):
    """Export a table with the T-gate count per circuit and optimisation function."""

    # convert timestamps to expired seconds since the initial entry
    def _timestamp_to_expired(self, df: DataFrame) -> DataFrame:
        """Add the column ``time-expired`` to ``df`` and return ``df``.

        The column holds the seconds between each ``timestamp`` and the
        earliest ``timestamp`` of the frame. ``df`` is modified in place.
        """
        t_init: datetime = df["timestamp"].min()
        df["time-expired"] = (df["timestamp"] - t_init).dt.total_seconds()
        return df

    def generate_table(
        self, csv_path: str = "./tcount.csv", tex_path: str = "./tcount.tex"
    ) -> None:
        """Write the ``c_tcount`` table as a CSV and as a LaTeX file.

        All files in ``self.path_to_data`` are re-imported and combined. The
        table has one row per circuit, sorted by ``q_num_qubits`` and
        ``c_gates``, and one column per ``func_opt`` whose name starts with
        ``idfs_``. A cell holds the first ``c_tcount`` value found for that
        circuit and function, or NaN if there is none.

        Args:
            csv_path: destination of the CSV export.
            tex_path: destination of the LaTeX export.
        """
        l_frames: list[DataFrame] = [self._import_csv(f) for f in self.path_to_data]
        # a single concat avoids re-copying the growing frame for every file
        df = pd.concat(l_frames) if l_frames else pd.DataFrame()

        l_func_opts: list[str] = df["func_opt"].unique().tolist()
        l_circuit: list[str] = df["circuit_name"].unique().tolist()

        df_tcount = df[df["metric"] == "c_tcount"]

        l_dic_tmp: list[dict] = []
        for circuit_name in l_circuit:
            df_tmp = df_tcount[df_tcount["circuit_name"] == circuit_name]
            dic_tmp: dict = {
                "circuit_name": circuit_name,
                "q_num_qubits": np.nan,
                "c_gates": np.nan,
            }

            for func_opt in l_func_opts:
                df_func = df_tmp[df_tmp["func_opt"] == func_opt]
                if df_func.empty:
                    # this circuit has no result for this function
                    dic_tmp[func_opt] = np.nan
                    continue

                dic_tmp[func_opt] = df_func["c_tcount"].iloc[0]
                dic_tmp["q_num_qubits"] = df_func["q_num_qubits"].iloc[0]
                dic_tmp["c_gates"] = df_func["c_gates"].iloc[0]

            l_dic_tmp.append(dic_tmp)

        # explicit columns keep the frame well-formed even without any rows
        df_filtered = pd.DataFrame(
            l_dic_tmp,
            columns=["circuit_name", "q_num_qubits", "c_gates", *l_func_opts],
        )
        df_filtered = df_filtered.sort_values(["q_num_qubits", "c_gates"])

        columns = df_filtered.columns.tolist()
        l_new_columns = (
            [i for i in columns if i == "circuit_name"]
            # + [i for i in columns if i == "q_num_qubits"]
            # + [i for i in columns if i == "c_gates"]
            # + [i for i in columns if i == "unoptimized"]
            # + [i for i in columns if re.match(r"^pyzx", i)]
            # + [i for i in columns if re.match(r"^dfs$", i)]
            # + [i for i in columns if re.match(r"^dfs_+", i)]
            + [i for i in columns if re.match(r"^idfs_+", i)]
        )

        df_filtered = df_filtered[l_new_columns]

        # TODO: planned but not implemented: mark every value with a unicode
        # arrow that compares it with the "pyzx_full_reduce" column:
        #   U+2192 (right)       equal
        #   U+21C8 (double up)   reduced to 0, or at least 20 % lower
        #   U+2191 (up)          0 % to below 20 % lower
        #   U+2193 (down)        up to below 20 % higher
        #   U+21CA (double down) at least 20 % higher

        df_filtered.to_csv(csv_path, index=False)
        df_filtered.to_latex(tex_path, index=False)


# build the T-count table from the final results file
table_generator: TableResults = TableResults(path_to_data="./results.csv")
table_generator.generate_table()
