import os
import sys
from glob import glob
import dill

# Make the project's modules importable when this file is run directly.
# `parent_dir` is the directory one level above the directory containing this
# file; each listed sub-folder of it is appended to sys.path, then the
# directory itself.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
for module_folder in ["algorithms", "benchmark", "zx_dfs", "statistics", "bin"]:
    sys.path.append(f"{parent_dir}/{module_folder}")
# Appended unconditionally once all sub-folders have been added.
sys.path.append(parent_dir)

from qiskit import QuantumCircuit
import qiskit as qk

from benchmark.run_benchmark import get_circuit_statistics, get_circuit_name

from zx_dfs.dfs import AllResults, Result
from algorithms.dfs import *
from benchmark.statistics import get_circuit_name, get_circuit_statistics

import pyzx as zx


# Relative path of a pickle directory.
# NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
pickle_dir: str = "./.pkl"
# Circuit names; presumably the base names of QASM files under `circuits/`
# (the script below loads "circuits/3_17tc.qasm") — TODO confirm.
l_test_circuits: list[str] = ["xor5d1", "3_17tc"]
# Labels of search variants; presumably they correspond to the suffix of the
# pickled result files (e.g. "..._3_17tc_dfs.pkl") — TODO confirm.
l_dfs_type: list[str] = [
    "unoptimized",
    "dfs",
    "dfs_strict_convergence",
    "dfs_no_colour_cycle_limit_colour_change",
    "dfs_limit_colour_change",
    "dfs_no_colour_cycle_max_depth",
]


def load_results(fname: str) -> AllResults:
    """Deserialize the object stored in the dill file at ``fname``.

    The file is opened in binary mode and handed to ``dill.load``; whatever
    object it contains is returned unchanged. Like ``pickle``, ``dill`` can
    execute arbitrary code while loading, so only open files you trust.

    Args:
        fname: Path of the dill/pickle file to read.

    Returns:
        The deserialized object (expected to be an ``AllResults``).
    """
    with open(fname, mode="rb") as pickle_file:
        return dill.load(pickle_file)


def extract_path(result: AllResults, metric) -> list[str] | None:
    """Return the rule sequence stored for one metric of ``result``.

    Args:
        result: Object whose attribute named ``metric`` holds a ``Result``.
        metric: Name of the attribute to read from ``result``.

    Returns:
        The ``l_rule_sequence`` attribute of the selected ``Result``.

    Raises:
        AttributeError: If ``result`` has no attribute called ``metric``.
    """
    return getattr(result, metric).l_rule_sequence


def apply_rules(circuit_path: str, l_rule_sequence: list[str] | None):
    """Replay a named sequence of rewrite rules on the ZX-graph of a circuit.

    The QASM file at ``circuit_path`` is loaded with Qiskit, dumped back to
    OpenQASM 2 and parsed by PyZX into a graph. The T-count of that graph and
    the graph itself are printed, then every name in ``l_rule_sequence`` is
    matched against the ``__name__`` of the known rewrite rules and the
    matching rule is applied in place with ``quiet=True``.

    Args:
        circuit_path: Path of the QASM file describing the circuit.
        l_rule_sequence: Rule names in the order they should be applied.
            ``None`` is treated as an empty sequence. Names that match no
            known rule are skipped without error.

    Returns:
        A copy of the graph after all matching rules were applied.
    """
    # Local import so the annotations below resolve without relying on a
    # star import elsewhere in the file.
    from collections.abc import Callable

    # import rewriting rules from pyzx
    from pyzx.simplify import (
        pivot_simp,
        pivot_gadget_simp,
        pivot_boundary_simp,
        lcomp_simp,
        bialg_simp,
        spider_simp,
        id_simp,
        gadget_simp,
        supplementarity_simp,
        copy_simp,
        phase_free_simp,
        interior_clifford_simp,
        clifford_simp,
    )

    # use custom rules that return whether a graph was changed or not
    from colour_change import (
        to_gh,
        to_rg,
    )

    from zx_dfs.rewrite_rules import w_fusion_simp, z_to_z_box_simp

    # all rules in a list
    l_rules: list[Callable] = [
        pivot_simp,
        pivot_gadget_simp,
        pivot_boundary_simp,
        lcomp_simp,
        spider_simp,
        bialg_simp,
        id_simp,
        gadget_simp,
        supplementarity_simp,
        copy_simp,
        phase_free_simp,
        w_fusion_simp,
        z_to_z_box_simp,
        # interior_clifford_simp,
        # clifford_simp,
        to_gh,
        # to_rg,
    ]

    # Name -> rule lookup. Building it from the reversed list lets the first
    # entry of l_rules win should two rules ever share a __name__.
    rules_by_name: dict[str, Callable] = {
        rule.__name__: rule for rule in reversed(l_rules)
    }

    qc_init: QuantumCircuit = QuantumCircuit.from_qasm_file(circuit_path)
    c_init = zx.Circuit.from_qasm(qk.qasm2.dumps(qc_init))
    g_init = c_init.to_graph()
    g = g_init.copy()
    print(zx.tcount(g))
    print(g)

    # apply rules to graph; a missing sequence (None) means "nothing to apply"
    for rule_name in l_rule_sequence or ():
        callable_rule = rules_by_name.get(rule_name)
        if callable_rule is not None:
            callable_rule(g, quiet=True)

    return g.copy()


# Ad-hoc check script; everything below runs at import time.
# Load one stored result from a hard-coded absolute path.
test = load_results(
    "/home/tobias/zx_dfs/tests/.pkl/_2024-07-31-00-27-04_3_17tc_dfs.pkl"
)

# Rule sequence recorded under the `c_tcount` attribute of the loaded result.
l_sequence: list[str] | None = extract_path(test, "c_tcount")

# Replay that sequence on the circuit file (path is relative to the current
# working directory) and print the T-count of the resulting graph.
g = apply_rules("circuits/3_17tc.qasm", l_sequence)
print(zx.tcount(g))

# Print the unitarity check of the replayed graph and its tensor comparison
# against the graph stored in the loaded result.
print(zx.is_unitary(g))
print(zx.compare_tensors(g, test.c_tcount.graph))

# Rebuild the unmodified graph of the same circuit as a reference.
qc_init: QuantumCircuit = QuantumCircuit.from_qasm_file("circuits/3_17tc.qasm")
c_init = zx.Circuit.from_qasm(qk.qasm2.dumps(qc_init))
g_init = c_init.to_graph()

# Compare the reference graph with the replayed graph and with the stored graph.
print(zx.compare_tensors(g, g_init))
print(zx.compare_tensors(g_init, test.c_tcount.graph))

# Baseline: run PyZX's full_reduce on a fresh graph of the circuit and print
# its T-count.
g_full = c_init.to_graph()
zx.simplify.full_reduce(g_full)
print(zx.tcount(g_full))

# Compare the full_reduce graph with the reference graph and the stored graph.
print(zx.compare_tensors(g_full, g_init))
print(zx.compare_tensors(g_full, test.c_tcount.graph))

# Print the circuit stored in the result, then the circuit extracted from the
# full_reduce graph (converted to a Qiskit circuit via its QASM text).
print(test.c_tcount.qc)
qc_full = QuantumCircuit.from_qasm_str(zx.extract_circuit(g_full).to_qasm())
print(qc_full)
