import qiskit as qk
from qiskit import QuantumCircuit
from qiskit.transpiler import PassManager
from qiskit.transpiler.passes import Unroll3qOrMore
import dill
import pyzx as zx
from pyzx.simplify import (
    to_graph_like,
    is_graph_like,
    Fraction,
    FractionLike,
    simp,
    MatchObject,
    RewriteOutputType,
    Stats,
)
from pyzx.circuit import Circuit
from pyzx.graph.base import BaseGraph, VT, ET
from pyzx import rules, tcount, extract_circuit

from zx_dfs.pruning import PruneColourCycle

from .dfs import DFS


from typing import Tuple, Callable, List, Any, Optional, Union, Type, TypeVar
from functools import partial
from copy import deepcopy
from dataclasses import dataclass, make_dataclass, fields
from itertools import chain
import inspect
import datetime as DT
from datetime import datetime, timedelta
from queue import LifoQueue
from metric import MetricTcount, MetricEdge

from local_dfs import LocalDFS
from local_idfs import LocalIDFS


@dataclass
class Result:
    """Outcome of an optimisation for a single metric.

    Instances are built positionally (field order matters) and replaced
    whenever a local search reports an improvement for that metric.
    """

    # Metric name (the metric object's ``__str__()`` value).
    metric: str
    # ZX-diagram holding the best result found for this metric.
    graph: BaseGraph
    # PyZX circuit extracted from ``graph``; initialised to None before finalisation.
    circuit: Circuit
    # Qiskit version of ``circuit``; initialised to None before finalisation.
    qc: QuantumCircuit
    # Presumably the id of the search-tree node that produced this result — TODO confirm.
    node: int | None
    # Presumably the search depth of that node — TODO confirm.
    depth: int
    # Presumably the node path from the search root — TODO confirm.
    l_path: list[int] | None
    # Names of the rewrite rules applied (presumably in application order — TODO confirm).
    l_rule_sequence: list[str] | None
    # Rules that LocalElimination replays in later local searches.
    l_bundled_rules: list[Callable] | None
    # Elapsed time in seconds; LocalElimination adds its own wall-clock offset.
    expired_time: float | None


@dataclass
class AllResults:
    """Static container for the results of a local-elimination run.

    NOTE(review): ``LocalElimination._setup_results`` builds its own dynamic
    dataclass named ``AllResults`` (one field per metric name) instead of
    instantiating this class, and no visible code in this module constructs
    it — confirm whether other modules rely on it.
    """

    # Result for a single metric.
    metric: Result | None = None
    # Presumably the number of search-tree nodes visited — TODO confirm.
    n_nodes: int | None = 0
    # Presumably the number of leaves reached — TODO confirm.
    n_leafs: int | None = 0
    # Presumably the deepest search level reached — TODO confirm.
    max_depth: int | None = 0
    # Presumably maps node id -> (parent id, rule name) — TODO confirm.
    tree: dict[int, Tuple[None | int, None | str]] | None = None
    # Whether the result is known to be optimal.
    optimal: bool = False
    # Initial circuit / PyZX circuit / ZX-graph before optimisation.
    qc_init: QuantumCircuit | None = None
    c_init: Circuit | None = None
    g_init: BaseGraph | None = None


"""
Local Elimination
"""


class LocalElimination:
    """Reduce a circuit with ZX-calculus searches confined to T-gate neighbourhoods.

    The input circuit is unrolled so that no gate acts on three or more
    qubits, imported into PyZX and converted to a graph-like ZX-diagram.
    :meth:`run` then, for every domain depth ``1..max_domain_depth``, grows a
    neighbourhood ("domain") around each vertex whose phase has denominator 4
    and runs ``LocalDFS`` restricted to that domain.  Whenever a metric
    reports an improvement, the improved graph becomes the working graph and
    the rules that produced it are replayed in later searches.
    """

    def __init__(
        self,
        qc: QuantumCircuit,
        max_duration: int | None,
        circ_name: str,
        func_opt: str,
        metric: list | None = None,
        pruning: list | None = None,
        max_domain_depth: int = 5,
        b_simp_one: bool = False,
        b_dump_results: bool = True,
    ) -> None:
        """Prepare the circuit and store the search configuration.

        Args:
            qc: Circuit to optimise.
            max_duration: Stored as ``self.max_duration``.
                NOTE(review): ``run`` does not read it, so no time limit is
                enforced — confirm whether that is intended.
            circ_name: Circuit name, used in the dump file name.
            func_opt: Label used in the dump file name; the class name is
                used instead when this is ``None``.
            metric: Metric objects to optimise for.  Defaults to a fresh
                ``[MetricTcount(), MetricEdge()]`` per instance.
            pruning: Pruning strategies passed to ``LocalDFS``.  Defaults to a
                fresh ``[PruneColourCycle()]`` per instance.
            max_domain_depth: Largest neighbourhood radius tried by ``run``.
            b_simp_one: Forwarded to ``LocalDFS``; when replaying bundled
                rules it selects a single rule call instead of ``simp``.
            b_dump_results: Whether ``run`` pickles its result with ``dill``.
        """
        # ``None`` sentinels instead of list defaults: a list default is
        # created once at definition time and shared by every instance.
        if metric is None:
            metric = [MetricTcount(), MetricEdge()]
        if pruning is None:
            pruning = [PruneColourCycle()]

        self.qc: QuantumCircuit = self._unroll_qc_2q(qc)
        self.g_init, self.c_init = self._import_circuit(self.qc)
        self.metric: list = metric
        self.pruning: list = pruning
        self.circ_name = circ_name
        self.func_opt = func_opt
        self.max_duration = max_duration
        self.max_domain_depth = max_domain_depth
        self.b_simp_one = b_simp_one
        self.b_dump_results: bool = b_dump_results

    def _import_circuit(self, qc: QuantumCircuit) -> Tuple[BaseGraph[VT, ET], Circuit]:
        """Convert a Qiskit circuit to a PyZX circuit and a graph-like ZX-graph.

        The conversion goes through OpenQASM 2; the graph is built from the
        basic-gate decomposition and made graph-like in place.
        """
        c: Circuit = Circuit.from_qasm(qk.qasm2.dumps(qc))
        g: BaseGraph[VT, ET] = c.to_basic_gates().to_graph()
        to_graph_like(g)
        return g, c

    def _unroll_qc_2q(self, qc: QuantumCircuit) -> QuantumCircuit:
        """Return *qc* with every gate on three or more qubits unrolled."""
        pm: PassManager = PassManager([Unroll3qOrMore()])
        return pm.run(qc)

    def _get_all_tgates(self, graph: BaseGraph) -> list[int]:
        """Return the ids of all vertices whose phase has a denominator > 2.

        The count is expected to agree with ``pyzx.tcount``; the assertion
        cross-checks it.
        """
        l_tgates: list[int] = [
            index
            for index, phase in graph.phases().items()
            if phase.denominator > 2
        ]
        # Internal consistency check against PyZX's own T-count.
        assert tcount(graph) == len(l_tgates)
        return l_tgates

    def _get_domain(
        self, graph: BaseGraph, vertix: int, max_depth: int = 3
    ) -> list[int]:
        """Return all vertices within *max_depth* edges of *vertix*.

        The result includes *vertix* itself and is in arbitrary order.
        """
        domain: set[int] = {vertix}
        frontier: set[int] = {vertix}
        # Breadth-first expansion: only the newest layer needs visiting.
        for _ in range(max_depth):
            next_frontier: set[int] = set()
            for index in frontier:
                next_frontier.update(graph.neighbors(index))
            next_frontier -= domain
            if not next_frontier:
                break
            domain |= next_frontier
            frontier = next_frontier
        return list(domain)

    def _is_in_domain(
        self, v: int | Tuple[int], l_domain: list[int], graph: BaseGraph
    ) -> bool:
        """Decide whether a rewrite candidate *v* lies inside *l_domain*.

        *v* is either a single vertex id or a tuple.  For a tuple, list
        entries are ignored, as are ids no longer present in *graph*; every
        remaining id must belong to *l_domain*.  Returns ``False`` when no id
        could be checked.
        """
        current_ids: set[int] = set(graph.vertex_set())
        domain: set[int] = set(l_domain)

        if isinstance(v, tuple):
            b_inside_window: bool = False
            for vertix in v:
                if isinstance(vertix, list):
                    continue
                if vertix in current_ids:
                    b_inside_window = vertix in domain
                    if not b_inside_window:
                        return False
            return b_inside_window

        if isinstance(v, int):
            return v in domain

        return False

    def _is_in_graph(self, vertix: int, graph: BaseGraph[VT, ET]) -> bool:
        """Return whether *vertix* is still a vertex of *graph*."""
        return vertix in graph.vertex_set()

    def _grow_domain(
        self,
        vertix: int,
        max_depth: int,
        graph: Optional[BaseGraph[VT, ET]] = None,
    ) -> list[int]:
        """Return the neighbourhood of *vertix* of radius *max_depth* in *graph*.

        Raises:
            TypeError: if *graph* is not supplied.
        """
        # ``graph`` keeps a default only for signature compatibility; a
        # neighbourhood cannot be computed without it.
        if graph is None:
            raise TypeError("_grow_domain() requires a graph")
        return self._get_domain(graph=graph, vertix=vertix, max_depth=max_depth)

    def _bundle_rules(
        self,
        g: BaseGraph[VT, ET],
        vertix_function: Callable,
        l_rules: list[Callable],
        b_simp_one: bool = True,
        name: str = "bundled_rules",
        quiet: bool = True,
        match=None,
        rewrite=None,
    ) -> bool:
        """Apply every rule in *l_rules* to *g*, restricted by *vertix_function*.

        Each rule's keyword arguments are taken from its last
        ``inspect.getmembers`` entry (for a ``functools.partial`` that is its
        ``keywords`` dict).  ``matchf`` and ``quiet`` are overridden.  The rule
        is then called directly (``b_simp_one``) or through ``simp``.
        ``name``, ``match`` and ``rewrite`` are accepted for signature
        compatibility but unused.

        Returns:
            ``True`` if at least one rule matched.
        """
        n_matched: int = 0
        for rule_primitive in l_rules:
            if rule_primitive is None:
                continue
            # Copy so that overriding matchf/quiet does not mutate the rule's
            # own stored keyword arguments.
            rule_kwargs: dict[str, Any] = dict(
                inspect.getmembers(rule_primitive)[-1][-1]
            )
            rule_kwargs["matchf"] = vertix_function
            rule_kwargs["quiet"] = quiet
            if "vertix_function" in rule_kwargs:
                continue

            if b_simp_one:
                matched = bool(rule_primitive(g, **rule_kwargs))
            else:
                matched = simp(g, **rule_kwargs) > 0
            if matched:
                n_matched += 1

        return n_matched > 0

    def _prepare_bundled_rules(
        self,
        l_rules: list[Callable],
        vertix_function: Callable,
        b_simp_one: bool = True,
    ):
        """Return ``_bundle_rules`` pre-bound to *l_rules* and *vertix_function*."""
        return partial(
            self._bundle_rules,
            vertix_function=vertix_function,
            l_rules=l_rules,
            b_simp_one=b_simp_one,
            name="bundled_rules",
            match=None,
        )

    def run(
        self,
    ):
        """Run the local elimination and return the results container.

        Returns the dataclass *class* built by ``_setup_results``.  Each
        metric-named attribute holds the best ``Result`` found (as returned
        by ``LocalDFS``).  When ``b_dump_results`` is set, the container is
        also pickled with ``dill``.
        """
        has_tgates: bool = len(self._get_all_tgates(self.g_init)) > 0
        best_result = self._setup_results()
        best_result = self._finalize_results(best_result)
        self.start_time = datetime.now()

        if has_tgates:
            graph: BaseGraph = deepcopy(self.g_init)
            # Rule bundles that produced improvements; replayed in later searches.
            l_bundled_rules: list[list[Callable]] = []

            for domain_depth in range(1, self.max_domain_depth + 1):
                for t_gate in self._get_all_tgates(graph):
                    graph_domain: BaseGraph = deepcopy(graph)

                    # An earlier improvement may have removed this vertex.
                    if not self._is_in_graph(t_gate, graph_domain):
                        continue
                    if graph_domain.phase(t_gate).denominator != 4:
                        continue

                    l_domain = self._grow_domain(t_gate, domain_depth, graph)
                    filter_function: Callable = partial(
                        self._is_in_domain, l_domain=l_domain, graph=graph_domain
                    )

                    self.current_time = datetime.now()
                    self.expired_time = (
                        self.current_time - self.start_time
                    ).total_seconds()

                    result = LocalDFS(
                        g=graph_domain,
                        vertix_id=t_gate,
                        vertices_function=filter_function,
                        l_bundled_rules=[
                            self._prepare_bundled_rules(
                                l_rules=bundle,
                                vertix_function=filter_function,
                                b_simp_one=self.b_simp_one,
                            )
                            for bundle in l_bundled_rules
                            if bundle
                        ],
                        metric=self.metric,
                        pruning=self.pruning,
                        b_simp_one=self.b_simp_one,
                    ).run()

                    for m in self.metric:
                        metric_name = m.__str__()
                        metric_result = getattr(result, metric_name)
                        old_result = getattr(best_result, metric_name)

                        if not m.check(metric_result.graph, old_result.graph):
                            continue

                        graph = deepcopy(metric_result.graph)
                        # Shift the search-local time by the elapsed wall time.
                        metric_result.expired_time = (
                            self.expired_time + metric_result.expired_time
                        )
                        setattr(best_result, metric_name, metric_result)

                        # Results may carry no rules (None); skip those so the
                        # replay list only holds real bundles.
                        if metric_result.l_bundled_rules:
                            l_bundled_rules.append(metric_result.l_bundled_rules)

        if self.b_dump_results:
            self._dump_results(best_result)

        return best_result

    def _finalize_results(self, results) -> Any:
        """Fill in the circuit representations of every metric result.

        For each field of *results* named after a metric, the stored graph is
        made graph-like in place if necessary.  A PyZX circuit is extracted
        from a copy of it and translated to a Qiskit circuit via QASM.
        Returns *results*, which is updated in place.
        """
        metric_names = {m.__str__() for m in self.metric}
        for var in fields(results):
            if var.name not in metric_names:
                continue
            metric_result = getattr(results, var.name)

            if not is_graph_like(metric_result.graph):
                to_graph_like(metric_result.graph)

            # Extract from a copy so the stored graph is left untouched.
            metric_result.circuit = extract_circuit(deepcopy(metric_result.graph))
            metric_result.qc = QuantumCircuit.from_qasm_str(
                metric_result.circuit.to_qasm()
            )
            setattr(results, var.name, metric_result)
        return results

    def _setup_results(self) -> type:
        """Build and initialise a dynamic results dataclass.

        The class has one ``Result`` field per metric name plus bookkeeping
        fields.  Initial values are stored as *class* attributes and the
        class itself (not an instance) is returned:

        - metric fields: a ``Result`` wrapping a copy of the initial graph;
        - ``n_nodes``/``n_leafs``/``max_depth``: ``0``;
        - ``node``/``tree``: ``None``;
        - ``qc_init``/``c_init``/``g_init``: copies of the initial objects;
        - ``optimal``: ``False``.
        """
        results_cls = make_dataclass(
            "AllResults",
            [(m.__str__(), Result) for m in self.metric]
            + [
                ("n_nodes", int),
                ("n_leafs", int),
                ("max_depth", int),
                ("node", None | int),
                ("tree", None | dict),
                ("qc_init", None | QuantumCircuit),
                ("c_init", None | Circuit),
                ("g_init", None | BaseGraph[VT, ET]),
            ],
        )

        for var in fields(results_cls):
            if var.name in ("n_nodes", "n_leafs", "max_depth"):
                value = 0
            elif var.name in ("node", "tree"):
                value = None
            elif var.name == "qc_init":
                value = deepcopy(self.qc)
            elif var.name == "c_init":
                value = deepcopy(self.c_init)
            elif var.name == "g_init":
                value = deepcopy(self.g_init)
            else:
                value = Result(
                    var.name,
                    deepcopy(self.g_init),
                    None,
                    None,
                    0,
                    0,
                    None,
                    None,
                    None,
                    0.0,
                )
            setattr(results_cls, var.name, value)
        results_cls.optimal = False

        return results_cls

    def _dump_results(self, obj, prefix: None | int = None) -> None:
        """Pickle *obj* with ``dill`` into ``_{prefix}_{circ_name}_{func_opt}.pkl``.

        *prefix* defaults to the current local time (``%Y-%m-%d-%H-%M-%S``);
        ``func_opt`` falls back to the class name when ``None``.
        """
        if prefix is not None:
            time_prefix: str = str(prefix)
        else:
            time_prefix = DT.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

        if self.func_opt is None:
            func_opt: str = f"{self.__class__.__name__}"
        else:
            func_opt = self.func_opt

        # Use a context manager so the file is flushed and closed.
        with open(f"_{time_prefix}_{self.circ_name}_{func_opt}.pkl", "wb") as fh:
            dill.dump(obj, fh, recurse=True)
