import qiskit as qk
from qiskit import QuantumCircuit
from qiskit.transpiler import PassManager
from qiskit.transpiler.passes import Unroll3qOrMore

import pyzx as zx
from pyzx.simplify import (
    to_graph_like,
    is_graph_like,
    Fraction,
    simp,
    MatchObject,
    RewriteOutputType,
    Stats,
)
from pyzx.circuit import Circuit
from pyzx.graph.base import BaseGraph, VT, ET
from pyzx import rules, tcount, extract_circuit

from dfs import DFS


from typing import Tuple, Callable, List, Any, Optional, Union, override
from functools import partial
from copy import deepcopy
from dataclasses import dataclass, make_dataclass, fields
from itertools import chain
import inspect
import datetime as DT
from datetime import datetime, timedelta
from queue import LifoQueue

from metric import Metric, MetricTcount
from pruning import PruneColourCycle


@dataclass
class Result:
    """Best solution found by the search for one metric.

    Besides the optimized graph and its circuits, it stores (when node tracing
    is enabled) the search-tree information needed to reproduce the solution.
    """

    # name of the metric this result belongs to
    metric: str
    # ZX-graph of the best solution
    graph: BaseGraph
    # pyzx circuit of the solution (presumably extracted from ``graph`` — TODO confirm)
    circuit: Circuit
    # qiskit circuit of the solution (presumably converted from ``circuit`` — TODO confirm)
    qc: QuantumCircuit
    # search-tree node id at which the solution was recorded
    node: Optional[int]
    depth: int
    # node ids traced back from ``node`` through the search tree
    l_path: Optional[list[int]]
    # names of the rules applied along ``l_path``
    l_rule_sequence: Optional[list[str]]
    # the rule callables applied along ``l_path``
    l_bundled_rules: Optional[list[Callable]]
    # NOTE(review): presumably the elapsed time when the result was found — confirm
    expired_time: Optional[float]


@dataclass
class AllResults:
    """Aggregated output of one search run.

    Holds the trace statistics of the DFS, the search tree and the initial
    graph/circuits. Per-metric best results are read and written by attribute
    name (``str(metric)``) in ``LocalDFS._save_results``.
    """

    metric: Result | None = None
    # number of traced search-tree nodes
    n_nodes: int | None = 0
    # number of leaves reached
    n_leafs: int | None = 0
    # maximum depth reached by the search
    max_depth: int | None = 0
    # search tree: node id -> (parent node id, rule callable used to reach the node);
    # the rule entry may be None (e.g. for the root)
    tree: dict[int, Tuple[None | int, None | Callable]] | None = None
    optimal: bool = False
    # initial circuits / graph the search started from
    qc_init: QuantumCircuit | None = None
    c_init: Circuit | None = None
    g_init: BaseGraph | None = None


""" Modified DFS algorithm that only applies rewrite rules inside a given local
domain (selected by ``vertices_function`` around a target vertex).
"""


class LocalDFS(DFS):
    def __init__(
        self,
        g: BaseGraph[VT, ET],
        vertix_id: int,
        vertices_function: Callable,
        l_bundled_rules: list[list[Callable]] | None = None,
        b_simp_one: bool = False,
        metric: list | None = None,
        pruning: list | None = None,
        b_quiet: bool = True,
        b_check_every_node: bool = True,
        mult: int = 2,
        max_depth: int = 15,
        circ_name: str | None = None,
        b_dump_results: bool = False,
        func_opt: str | None = None,
        b_max_duration: bool = True,
        max_duration: int = 5,  # maximum duration in seconds
        b_snapshot: bool = True,  # for csv files
        b_snapshot_pkl: bool = False,  # for pickles objects with graphs
        snapshot_frequency: int = 1,  # snapshot every n seconds
        b_full_start: bool = False,
    ):
        """Set up a local DFS that starts from a ZX-graph instead of a quantum circuit.

        Args:
            g: start graph; converted to graph-like form in place if necessary.
            vertix_id: target vertex used as the centre of the local domain.
            vertices_function: filter handed to the rule matchers so that only
                matches inside the local domain are considered.
            l_bundled_rules: rule sequences that eliminated previous target
                vertices; they are explored before the individual rules.
            b_simp_one: if True, every single match of a rule becomes its own
                child node; otherwise rules are applied via pyzx ``simp``.
            metric: metrics to optimize for. ``None`` (default) means
                ``[MetricTcount()]``, created fresh for every instance.
            pruning: pruning conditions. ``None`` (default) means
                ``[PruneColourCycle()]``, created fresh for every instance.
            b_quiet: no output while rules change the graph.
            b_check_every_node: evaluate the metric at every node, not only leaves.
            mult: pruning parameter.
            max_depth: maximum search depth.
            circ_name: used for the filename when results are dumped.
            b_dump_results: whether to dump the result object.
            func_opt: identifier for the pickle object (suffixed with ``vertix_id``).
            b_max_duration: limit the runtime of the DFS.
            max_duration: maximum runtime in seconds.
            b_snapshot: take snapshots (csv files).
            b_snapshot_pkl: also pickle snapshot objects with graphs.
            snapshot_frequency: take a snapshot every n seconds.
            b_full_start: stored as ``self.b_full_start``.
        """

        # init arguments needed for the compiler pass; kept as the first
        # statement in case it inspects this frame's locals — TODO confirm
        self.my_init_arguments = self._get_init_arguments()

        # fresh default instances per object: list defaults in the signature
        # would be one shared object for every LocalDFS instance
        if metric is None:
            metric = [MetricTcount()]
        if pruning is None:
            pruning = [PruneColourCycle()]

        if not is_graph_like(g):
            to_graph_like(g)

        self.g_init: BaseGraph[VT, ET] = g
        self.c_init: Circuit = extract_circuit(g.copy()).to_basic_gates()
        self.qc_init: QuantumCircuit = self._unroll_qc_2q(
            QuantumCircuit.from_qasm_str(self.c_init.to_qasm())
        )

        # --- local elimination specific ---

        # target vertix that is used as the center of the domain
        self.vertix_id: int = vertix_id

        # additional filter function that rejects matches outside the local domain
        self.vertices_function: Callable = vertices_function

        # list of rule sequences that were able to eliminate previous target vertices
        self.l_bundled_rules: list[list[Callable]] | None = l_bundled_rules

        # each match of a rule results into a child node
        self.b_simp_one: bool = b_simp_one

        # two strategies
        # 1. first optimize T-gates until maximum domain depth was reached, then use result to optimize next metric
        # 2. order when a solution is accepted
        # to be implemented for all solvers

        # the metric to optimize for
        self.metric: list = metric
        self.pruning: list = pruning

        # data class that stores our results
        self.AllResults = self._setup_results()

        # circ name is required for filename if we dump our result object
        self.circ_name: str | None = circ_name

        # whether to dump result object
        self.b_dump_results: bool = b_dump_results
        # identifier for pickle object
        self.func_opt: str | None = f"{func_opt}-{vertix_id}"

        # --- tracing and dfs behaviour ---

        # no output if rules change graph
        self.b_quiet: bool = b_quiet

        # trace dfs forced because of rule bundling
        self.b_trace_leafs: bool = True
        self.b_trace_nodes: bool = True
        self.b_trace_depth: bool = True

        # dfs depth
        self.depth: int = 0

        # check whether to check every node
        self.b_check_every_node: bool = b_check_every_node

        # snapshot run every n seconds
        self.b_snapshot: bool = b_snapshot
        self.snapshot_frequency: int = snapshot_frequency
        self.b_snapshot_pkl: bool = b_snapshot_pkl

        # --- pruning rules ---
        self.mult: int = mult

        self.max_depth: int = max_depth

        # limit the time a dfs instance is allowed to run
        self.b_max_duration: bool = b_max_duration
        self.max_duration: int = max_duration  # seconds

        # use default rule setup rules
        self.l_zx_rules: list[Callable] = self._zx_rules()

        # set up dictionary for path setup
        self.tree: dict = self._setup_tree()

        self.start_time = DT.datetime.now()

        self.b_full_start: bool = b_full_start

    def _is_in_graph(self, vertix: int, graph: BaseGraph[VT, ET]) -> bool:
        """Return True if ``vertix`` is one of the vertices of ``graph``."""
        return vertix in graph.vertex_set()

    def _zx_rules(self) -> list[Callable]:
        """Build the list of rules tried at every DFS node.

        Previously found rule bundles (``self.l_bundled_rules``, skipping
        ``None`` entries) come first. They are followed by one
        ``functools.partial`` per (match, rewrite) pair of ``_rules_only``:
        - with ``b_simp_one`` the pair is bound to ``self._simp_one``;
        - otherwise it is bound to pyzx ``simp``, filtered by
          ``self.vertices_function``.

        The ``name`` and ``match`` keywords of the partials are read back
        later by key.
        """
        pairs: list[Tuple[Callable, Callable]] = self._rules_only()

        bundles: list[Callable] = [
            bundle for bundle in (self.l_bundled_rules or []) if bundle is not None
        ]

        if self.b_simp_one:
            wrapped = [
                partial(self._simp_one, name=m_fn.__name__, match=m_fn, rewrite=r_fn)
                for m_fn, r_fn in pairs
            ]
        else:
            wrapped = [
                partial(
                    simp,
                    name=m_fn.__name__,
                    match=m_fn,
                    rewrite=r_fn,
                    matchf=self.vertices_function,
                )
                for m_fn, r_fn in pairs
            ]

        return bundles + wrapped

    def _simp_one(
        self,
        g: BaseGraph[VT, ET],
        name: str,
        match: Callable[..., List[MatchObject]],
        rewrite: Callable[
            [BaseGraph[VT, ET], List[MatchObject]], RewriteOutputType[ET, VT]
        ],
        matchf: Optional[Union[Callable[[ET], bool], Callable[[VT], bool]]] = None,
        quiet: bool = True,
        stats: Optional[Stats] = None,
        matches_to_select: MatchObject | None = None,
    ) -> int:
        """Apply ``rewrite`` to ``g`` in place for the selected match(es).

        Args:
            g: graph to rewrite (modified in place).
            name: rule name, used for output and statistics.
            match: match function of the rule.
            rewrite: rewrite function of the rule.
            matchf: filter passed to ``match`` when no explicit match is given.
            quiet: suppress all printing when True.
            stats: optional rewrite statistics to update.
            matches_to_select: a single match to apply. If given, ``match`` and
                ``matchf`` are not consulted.

        Returns:
            1 if a rewrite was applied, otherwise 0. The return value is always
            0 when neither ``matches_to_select`` nor ``matchf`` is given.
        """
        if matches_to_select is not None:
            l_selected_matches: List[MatchObject] = [matches_to_select]
        elif matchf is not None:
            l_selected_matches = match(g, matchf)
        else:
            # neither an explicit match nor a filter: nothing is rewritten
            l_selected_matches = []

        if not l_selected_matches:
            return 0

        if not quiet:
            print("{}: ".format(name), end="")
            print(len(l_selected_matches), end="")

        etab, rem_verts, rem_edges, check_isolated_vertices = rewrite(
            g, l_selected_matches
        )

        g.add_edge_table(etab)
        g.remove_edges(rem_edges)
        g.remove_vertices(rem_verts)

        if check_isolated_vertices:
            g.remove_isolated_vertices()

        if not quiet:
            print(". ", end="")

        if stats is not None:
            stats.count_rewrites(name, len(l_selected_matches))
            # printing is debug output and therefore respects ``quiet``
            if not quiet:
                print(stats)

        return 1

    def _rules_only(
        self,
    ) -> list[Tuple[Callable, Callable]]:
        """Return the (match, rewrite) pairs of the pyzx rules explored by the DFS.

        The match function of each pair finds candidate locations, and the
        paired rewrite function applies them. The Z-box and W-fusion rules are
        disabled. Their functions are intentionally not imported, so that the
        unused names are not required at call time.
        """
        # import rewriting rules from pyzx
        from pyzx.rules import (
            match_ids_parallel,
            match_spider_parallel,
            match_phase_gadgets,
            match_supplementarity,
            match_copy,
            match_pivot_parallel,
            match_pivot_boundary,
            match_pivot_gadget,
            match_lcomp_parallel,
        )

        from pyzx.rules import (
            remove_ids,
            spider,
            merge_phase_gadgets,
            apply_supplementarity,
            apply_copy,
            pivot,
            lcomp,
        )

        # all rules in a list; the three pivot variants share the ``pivot`` rewrite
        l_zx_rules: list[Tuple[Callable, Callable]] = [
            (match_ids_parallel, remove_ids),
            (match_spider_parallel, spider),
            (match_phase_gadgets, merge_phase_gadgets),
            (match_supplementarity, apply_supplementarity),
            # disabled (would need pyzx.rules imports):
            # (match_z_to_z_box_parallel, z_to_z_box),
            # (match_w_fusion_parallel, w_fusion),
            (match_copy, apply_copy),
            (match_pivot_parallel, pivot),
            (match_pivot_boundary, pivot),
            (match_pivot_gadget, pivot),
            (match_lcomp_parallel, lcomp),
        ]

        return l_zx_rules

    def run_simp_one(self):
        """Run the DFS in which every single match of a rule is its own child node.

        Rule partials with ``match=None`` (rule bundles) are applied as a whole
        and produce at most one child. Every other rule is matched inside the
        local domain (``self.vertices_function``). Each match is applied to its
        own copy of the parent graph and becomes a separate child node.

        Nodes without any applicable match, and pruned children, are counted as
        leaves and evaluated with ``_check_metric``.

        Returns:
            The results object ``self.AllResults``.
        """
        dic_root: dict = self._setup()

        s: LifoQueue = LifoQueue()  # LIFO is stack for dfs
        s.put(dic_root)

        # compute maximum allowed runtime of dfs
        if self.b_max_duration:
            end_time: datetime = datetime.now() + timedelta(seconds=self.max_duration)

        while s.qsize() > 0:
            dic_stack = s.get()
            if not is_graph_like(dic_stack["g"]):
                to_graph_like(dic_stack["g"])

            # terminate dfs if maximum allowed runtime was reached
            if self.b_max_duration:
                current_time: datetime = datetime.now()
                if current_time > end_time:
                    break

            for rule in self.l_zx_rules:
                parent_graph = deepcopy(dic_stack["g"])
                match = inspect.getmembers(rule)[-1][-1]["match"]

                if match is None:
                    # rule bundle: applied as one unit, yields at most one child
                    dic_child, node, current_depth = self._simp_one_child(
                        rule, dic_stack
                    )
                    g = deepcopy(parent_graph)
                    cntr: int = rule(g=g)

                    if cntr > 0:
                        # the child carries the rewritten graph ``g``
                        self._simp_one_push_or_leaf(
                            s, dic_child, dic_stack, g, node, current_depth
                        )
                    else:
                        self._check_metric(parent_graph, node, current_depth)
                        if self.b_trace_leafs:
                            self.n_leafs += 1
                    continue

                all_matches = match(parent_graph, self.vertices_function)

                if len(all_matches) == 0:
                    # nothing applicable inside the local domain: leaf node
                    _, node, current_depth = self._simp_one_child(rule, dic_stack)
                    self._check_metric(parent_graph, node, current_depth)
                    if self.b_trace_leafs:
                        self.n_leafs += 1
                    continue

                for m in all_matches:
                    # a fresh child dict per match; sharing one dict would make
                    # every stacked entry alias the graph of the last match
                    dic_child, node, current_depth = self._simp_one_child(
                        rule, dic_stack
                    )
                    g = deepcopy(parent_graph)
                    rule(g=g, matches_to_select=m)
                    self._simp_one_push_or_leaf(
                        s, dic_child, dic_stack, g, node, current_depth
                    )

        self._finalize_results()
        self._save_results()

        return self.AllResults

    def _simp_one_child(
        self, rule: Callable, dic_stack: dict
    ) -> Tuple[dict, int | None, Any]:
        """Create a new child node of ``dic_stack`` that is reached via ``rule``.

        Returns:
            A tuple ``(child dict, node id, depth)``. The depth is ``None``
            when depth tracing is disabled.
        """
        dic_child, node = self._get_current_node(rule, dic_stack)
        current_depth = dic_child["cntr_max_depth"] if self.b_trace_depth else None
        return dic_child, node, current_depth

    def _simp_one_push_or_leaf(
        self,
        s: LifoQueue,
        dic_child: dict,
        dic_stack: dict,
        g: BaseGraph[VT, ET],
        node: int | None,
        current_depth: Any,
    ) -> None:
        """Push ``g`` as a child onto the stack, or evaluate it as a leaf if pruned."""
        dic_child["g"] = g
        if not self._check_pruning_conditions(dic_child, dic_stack):
            s.put(dic_child)
            return

        if self.b_trace_leafs:
            self.n_leafs += 1
        # evaluate this child's own (rewritten) graph, not its parent's
        self._check_metric(g, node, current_depth, self.metric)

    def run_simp_packed(self):
        """Run the DFS in which each rule is applied as a whole (``rule(graph, quiet=...)``).

        Every rule in ``self.l_zx_rules`` yields one child per visited node:
        - If the rule changed the graph, the child is pushed onto the stack,
          unless a pruning condition holds.
        - If the rule changed nothing, the child is counted as a leaf and
          evaluated.

        With ``b_check_every_node``, every child is additionally evaluated
        before its rule is applied.

        Returns:
            The results object ``self.AllResults``.
        """
        root: dict = self._setup()

        stack: LifoQueue = LifoQueue()  # a LIFO queue acts as the DFS stack
        stack.put(root)

        # wall-clock deadline for the whole search (None = unlimited)
        deadline: datetime | None = (
            datetime.now() + timedelta(seconds=self.max_duration)
            if self.b_max_duration
            else None
        )

        while stack.qsize() > 0:
            parent = stack.get()

            if deadline is not None and datetime.now() > deadline:
                break

            for rule in self.l_zx_rules:
                # parameters of the child node reached via ``rule``
                child, node = self._get_current_node(rule, parent)
                depth = child["cntr_max_depth"] if self.b_trace_depth else None

                graph = deepcopy(parent["g"])

                if self.b_check_every_node:
                    # NOTE(review): this runs before ``rule`` is applied, so the
                    # parent's graph is recorded under the child's node id — confirm intended
                    self._check_metric(graph, node, depth, self.metric)

                if not is_graph_like(graph):
                    to_graph_like(graph)

                n_rewrites = rule(graph, quiet=self.b_quiet)

                if n_rewrites > 0:
                    # rule still applicable: descend unless pruned
                    child["g"] = deepcopy(graph)
                    if not self._check_pruning_conditions(child, parent):
                        stack.put(child)
                    continue

                # rule changed nothing: this child is a leaf
                if self.b_trace_leafs:
                    self.n_leafs += 1

                self._check_metric(graph, node, depth, self.metric)

        self._finalize_results()
        self._save_results()

        return self.AllResults

    def run(self) -> AllResults:
        """Main method: perform the optimization and return the results.

        Dispatches to ``run_simp_one`` (one child per individual match) when
        ``b_simp_one`` is set, otherwise to ``run_simp_packed``.
        """
        search = self.run_simp_one if self.b_simp_one else self.run_simp_packed
        search()
        return self.AllResults

    def _get_rule_name_from_partial(self, rule: Callable) -> str:
        """Return the ``name`` keyword that a rule partial was created with.

        Rules built by ``_zx_rules`` are ``functools.partial`` objects, so the
        name is read directly from ``rule.keywords``. This avoids
        ``inspect.getmembers(rule)[-1][-1]``, which has two drawbacks:
        - it only works while ``keywords`` is the alphabetically last attribute;
        - it evaluates every attribute of the object on each call.

        That form is kept only as a fallback for other callables.

        Raises:
            KeyError: if the rule carries no ``name`` keyword.
        """
        keywords = getattr(rule, "keywords", None)
        if isinstance(keywords, dict) and "name" in keywords:
            return keywords["name"]
        return inspect.getmembers(rule)[-1][-1]["name"]

    def _get_current_node(
        self, rule: Callable, dic_stack: dict
    ) -> Tuple[dict, int | None]:
        """Create the child node that is reached from ``dic_stack`` by applying ``rule``.

        With node tracing enabled, the method also:
        - increments the node counter;
        - assigns the new id to the child;
        - records ``(parent node, rule)`` in ``self.tree`` for path tracing.

        The child's depth is updated through ``_update_depth``.

        Returns:
            ``(child dict, node id)``; the node id is ``None`` when node
            tracing is disabled.
        """
        dic_child = self._node_skeleton()
        dic_child["past_rule"] = self._get_rule_name_from_partial(rule)

        node: int | None = None
        if self.b_trace_nodes:
            self.n_nodes += 1
            node = self.n_nodes
            dic_child["node"] = node
            # remember parent and rule so the path to this node can be traced back
            self.tree[node] = (dic_stack["node"], rule)

        return self._update_depth(dic_stack, dic_child), node

    def _save_results(self) -> None:
        """Copy trace statistics and the path of every best result into ``self.AllResults``.

        For each metric, the node of its stored result is traced back through
        ``self.tree``. The node ids, rule names and rule callables along that
        path are attached to the result. Metrics without a stored result are
        skipped.
        """
        # save potential trace to results
        if self.b_trace_nodes:
            self.AllResults.n_nodes = self.n_nodes
            self.AllResults.tree = self.tree

            # compute path for best results
            for m in self.metric:
                metric_name = str(m)
                metric_result = getattr(self.AllResults, metric_name, None)
                if metric_result is None:
                    # no result was recorded for this metric: nothing to trace
                    continue

                # get path
                l_path = self._trace_path(metric_result.node)
                metric_result.l_path = l_path

                # get sequence (entries without a rule, e.g. the root, are skipped)
                metric_result.l_rule_sequence = [
                    self._get_rule_name_from_partial(self.tree[i][-1])
                    for i in l_path
                    if self.tree[i][-1] is not None
                ]

                metric_result.l_bundled_rules = [
                    self.tree[i][1] for i in l_path if self.tree[i][-1] is not None
                ]

                # update the results *instance*; the bare name ``AllResults`` is
                # the module-level dataclass, and setting it there leaks results
                # into every later run
                setattr(self.AllResults, metric_name, metric_result)

        if self.b_trace_leafs:
            self.AllResults.n_leafs = self.n_leafs
        if self.b_trace_depth:
            self.AllResults.max_depth = self.depth

        # TODO: when ``self.b_dump_results`` is set, pickle the results via
        # ``self._dump_results()`` to reuse the graphs later on
