import os  # allow imports
import sys
import dill

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the modules inside each sibling project folder importable by their bare
# names, then add the project root itself so package-style imports such as
# ``benchmark.run_benchmark`` resolve. Entries are appended, so they never
# shadow modules already on sys.path.
for module_folder in ["algorithms", "benchmark", "zx_dfs", "statistics", "bin"]:
    sys.path.append(f"{parent_dir}/{module_folder}")
sys.path.append(parent_dir)

import argparse
from glob import glob

from typing import Callable

from benchmark.run_benchmark import benchmark_run
from algorithms.pyzx import (
    unoptimized,
    pyzx_full_reduce,
    pyzx_teleport_reduce,
    pyzx_scalar_reduce,
)
import qiskit as qk
import pyzx as zx


def cli_parser():
    """Build the command-line parser for this result-checking script.

    Returns:
        argparse.ArgumentParser: parser exposing ``-f/--files`` (list of
        paths, ``None`` when omitted) and the boolean flags ``-c/--circuit``,
        ``-g/--graph`` and ``-ss/--sequence``.
    """
    parser = argparse.ArgumentParser(
        description="Inspect and verify pickled benchmark results."
    )
    parser.add_argument(
        "-f",
        "--files",
        nargs="+",
        type=str,
        help="result pickle file(s) to check; defaults to every *.pkl file "
        "in the current directory",
    )
    parser.add_argument(
        "-c",
        "--circuit",
        action="store_true",
        help="compare the tensor of the initial circuit with the optimized "
        "(g_tcount) circuit",
    )
    parser.add_argument(
        "-g",
        "--graph",
        action="store_true",
        help="compare the tensor of the initial graph with the optimized "
        "(g_tcount) graph",
    )
    parser.add_argument(
        "-ss",
        "--sequence",
        action="store_true",
        help="print the rule sequence stored with the optimized (g_tcount) result",
    )

    return parser


def is_identical_graph(g_org, g_opt) -> bool:
    """Return True if two graphs have the same vertices, edges and edge types.

    Only connectivity is compared; vertex types and phases are NOT checked.

    Args:
        g_org: the original graph (must provide ``vertices()``,
            ``edge_set()`` and ``edge_type(edge)``).
        g_opt: the graph to compare against ``g_org`` (same interface).

    Returns:
        True when both graphs have the same vertex set, the same edge set,
        and every edge has the same ``edge_type`` in both graphs;
        False otherwise.
    """
    # Compare as sets so that equality depends on membership only, never on
    # the iteration order of the underlying containers.
    if set(g_org.vertices()) != set(g_opt.vertices()):
        return False

    edges_org = set(g_org.edge_set())
    if edges_org != set(g_opt.edge_set()):
        return False

    # Same edges: they must also carry the same edge type in both graphs.
    return all(g_org.edge_type(e) == g_opt.edge_type(e) for e in edges_org)


if __name__ == "__main__":

    parser = cli_parser()
    args = parser.parse_args()
    if args.files:
        l_path = args.files
    else:
        # No files given: check every pickled result in the current directory.
        l_path = glob("*.pkl")

    for circ_pkl in l_path:
        try:
            # NOTE(security): dill.load can execute arbitrary code embedded in
            # the pickle; only run this script on result files you trust.
            # Loading is inside the try so one unreadable file does not abort
            # the whole batch, and the file is closed before the comparisons.
            with open(circ_pkl, "rb") as f:
                all_results = dill.load(f)

            if args.sequence:
                print(all_results.g_tcount.l_rule_sequence)

            if args.graph:
                # Tensor comparison of the initial graph vs. the optimized graph.
                b_test = zx.compare_tensors(
                    all_results.g_init, all_results.g_tcount.graph
                )
                print(
                    f"Graph-compare:\t{b_test}\t\t{circ_pkl}\t\t{all_results.g_tcount.circuit.tcount()}"
                )

            if args.circuit:
                # Tensor comparison of the initial circuit vs. the optimized
                # circuit; if that fails, fall back to checking whether the
                # optimized graph is structurally identical to the initial one.
                b_test = zx.compare_tensors(
                    all_results.c_init, all_results.g_tcount.circuit
                )

                if not b_test:
                    b_test = is_identical_graph(
                        all_results.g_init, all_results.g_tcount.graph
                    )
                print(
                    f"Circuit-compare:\t{b_test}\t\t{circ_pkl}\t\t{all_results.g_tcount.circuit.tcount()}"
                )

        except Exception as error:
            # Top-level per-file boundary: report which file failed and move
            # on to the next one.
            print(f"{circ_pkl}: {error}")
