import qiskit as qk
from qiskit import QuantumCircuit
from qiskit.transpiler import PassManager
from qiskit.transpiler.passes import Unroll3qOrMore

import pyzx as zx
from pyzx.simplify import (
    to_graph_like,
    is_graph_like,
    Fraction,
    simp,
    MatchObject,
    RewriteOutputType,
    Stats,
)
from pyzx.circuit import Circuit
from pyzx.graph.base import BaseGraph, VT, ET
from pyzx import rules, tcount, extract_circuit

from dfs import DFS


from typing import Tuple, Callable, List, Any, Optional, Union
from functools import partial
from copy import deepcopy
from dataclasses import dataclass, make_dataclass, fields
from itertools import chain
import inspect
import datetime as DT
from datetime import datetime, timedelta
from queue import LifoQueue

from local_dfs import LocalDFS

"""
Modified IDDFS algorithm that only works inside given domains
"""


class LocalIDFS(LocalDFS):
    """Iterative-deepening depth-first search built on top of ``LocalDFS``.

    Both entry points run a stack-based DFS over the rewrite rules in
    ``self.l_zx_rules`` with a depth limit that starts at 1.  Whenever a rule
    still applies at the current depth limit, the whole search is restarted
    from the root with the limit increased by one.  Before every restart the
    results and tree of the previous iteration are snapshotted so they can be
    restored if the optional time budget (``self.b_max_duration`` /
    ``self.max_duration`` seconds) runs out in the middle of an iteration.

    The methods rely on state and helpers that are not defined in this class
    (``_setup``, ``_setup_results``, ``_setup_tree``, ``_get_current_node``,
    ``_explore_node``, ``_save_results``, ``AllResults``, ``tree``,
    ``n_leafs`` and the ``b_*`` flags); they are presumably provided by
    ``LocalDFS`` -- confirm there.
    """

    def run_simp_packed(self):
        """Run the iterative-deepening search applying each rule "packed".

        Each rule is called once per visited node as
        ``rule(graph, quiet=self.b_quiet)`` and its return value is used as
        the number of rewrites performed:

        * ``> 0`` at the current depth limit: request another, deeper
          iteration (the child is not pushed, so this iteration can end);
        * ``> 0`` below the limit: push the rewritten graph on the stack;
        * ``0``: treat the node as a leaf and explore it.

        Results are stored on ``self`` and finally handed to
        ``self._save_results()``; nothing is returned.
        """
        max_level: int = 1

        b_iterative_deepening: bool = True

        # compute maximum allowed runtime of dfs
        if self.b_max_duration:
            self.end_time: datetime = datetime.now() + timedelta(
                seconds=self.max_duration
            )

        # NOTE(review): this value is overwritten at the top of the loop
        # before it is read; the call is kept in case _setup_results() has
        # side effects.
        TMPResults = self._setup_results()
        while b_iterative_deepening:
            dic_root: dict = self._setup()
            if not is_graph_like(dic_root["g"]):
                to_graph_like(dic_root["g"])

            # snapshot the outcome of the previous iteration so it can be
            # restored if this iteration is cut short by the time limit
            TMPResults = deepcopy(self.AllResults)
            TMPTree = deepcopy(self.tree)

            # start this iteration from a clean tree / result set
            self.tree = self._setup_tree()
            self.AllResults = self._setup_results()

            s: LifoQueue = LifoQueue()  # LIFO is stack for dfs
            s.put(deepcopy(dic_root))

            # set again below if some rule still applies at the depth limit
            b_iterative_deepening = False

            while s.qsize() > 0:

                # terminate dfs if maximum allowed runtime was reached
                if self.b_max_duration:
                    self.current_time: datetime = datetime.now()
                    if self.current_time > self.end_time:
                        # discard the partial iteration
                        self.AllResults = TMPResults
                        self.tree = TMPTree
                        break

                dic_stack = s.get()

                for rule in self.l_zx_rules:

                    # get a dictionary with all parameters of the currently
                    # visited node
                    dic_child, node = self._get_current_node(rule, dic_stack)

                    # get depth of the current node
                    if self.b_trace_depth:
                        current_depth = dic_child["cntr_max_depth"]
                    else:
                        current_depth = None

                    # NOTE: tree/rule pruning via
                    # self._check_pruning_conditions_tree_rule(dic_stack,
                    # dic_child, rule) was commented out here and is not
                    # performed.

                    # apply rule to a private copy of the parent graph
                    child_graph = deepcopy(dic_stack["g"])

                    if not is_graph_like(child_graph):
                        to_graph_like(child_graph)

                    cntr = rule(child_graph, quiet=self.b_quiet)

                    # NOTE: circuit pruning via
                    # self._check_pruning_conditions_circuit(...) and the
                    # per-node exploration guarded by self.b_check_every_node
                    # were commented out here and are not performed.

                    # only deepen to next level if we changed the graph at
                    # the current max level
                    if (dic_child["cntr_max_depth"] == max_level) and (cntr > 0):
                        b_iterative_deepening = True

                        # not pushing to stack here so we can terminate

                    # push to stack if rule can still be applied
                    elif cntr > 0:
                        dic_child["g"] = deepcopy(child_graph)
                        s.put(dic_child)

                    # explore node properties at leaf node
                    else:
                        if self.b_trace_leafs:
                            self.n_leafs += 1
                        # Best effort: any failure while exploring a leaf is
                        # ignored and the search continues.
                        try:
                            self._explore_node(
                                deepcopy(child_graph), node, current_depth
                            )
                        except Exception:
                            pass

            # update if solution globally optimal or not
            if self.b_max_duration:
                if self.end_time < self.current_time:
                    self.AllResults.optimal = False
                else:
                    if not self.b_kill_optimal:
                        self.AllResults.optimal = True
            else:
                if not self.b_kill_optimal:
                    self.AllResults.optimal = True
                else:
                    self.AllResults.optimal = False

            # terminate dfs if maximum allowed runtime was reached
            if self.b_max_duration:
                self.current_time: datetime = datetime.now()
                if self.current_time > self.end_time:
                    self.AllResults = TMPResults
                    self.tree = TMPTree
                    break

            # only deepen if tree is not finished
            if b_iterative_deepening:
                max_level += 1

        # update and save results from dfs
        self._save_results()

    def run_simp_one(self):
        """Run the iterative-deepening search applying one match at a time.

        For every visited node and rule, a ``"match"`` entry is looked up on
        the rule object:

        * if it is not ``None`` it is called as
          ``match(graph, self.vertices_function)`` and every returned match
          becomes its own child, produced by
          ``rule(g=copy, matches_to_select=m)``; no matches means the node is
          a leaf and is explored;
        * if it is ``None`` the rule is applied once as ``rule(g=copy)`` and
          its return value is used as the number of rewrites, exactly as in
          ``run_simp_packed``.

        Results are stored on ``self`` and finally handed to
        ``self._save_results()``; nothing is returned.
        """
        max_level: int = 1

        b_iterative_deepening: bool = True

        # compute maximum allowed runtime of dfs
        if self.b_max_duration:
            self.end_time: datetime = datetime.now() + timedelta(
                seconds=self.max_duration
            )

        # NOTE(review): this value is overwritten at the top of the loop
        # before it is read; the call is kept in case _setup_results() has
        # side effects.
        TMPResults = self._setup_results()
        while b_iterative_deepening:
            dic_root: dict = self._setup()
            if not is_graph_like(dic_root["g"]):
                to_graph_like(dic_root["g"])

            # snapshot the outcome of the previous iteration so it can be
            # restored if this iteration is cut short by the time limit
            TMPResults = deepcopy(self.AllResults)
            TMPTree = deepcopy(self.tree)

            # start this iteration from a clean tree / result set
            self.tree = self._setup_tree()
            self.AllResults = self._setup_results()

            s: LifoQueue = LifoQueue()  # LIFO is stack for dfs
            s.put(deepcopy(dic_root))

            # set again below if some rule still applies at the depth limit
            b_iterative_deepening = False

            while s.qsize() > 0:

                # terminate dfs if maximum allowed runtime was reached
                if self.b_max_duration:
                    self.current_time: datetime = datetime.now()
                    if self.current_time > self.end_time:
                        # discard the partial iteration
                        self.AllResults = TMPResults
                        self.tree = TMPTree
                        break

                dic_stack = s.get()

                for rule in self.l_zx_rules:

                    # get a dictionary with all parameters of the currently
                    # visited node
                    dic_child, node = self._get_current_node(rule, dic_stack)

                    # get depth of the current node
                    if self.b_trace_depth:
                        current_depth = dic_child["cntr_max_depth"]
                    else:
                        current_depth = None

                    # NOTE: tree/rule pruning via
                    # self._check_pruning_conditions_tree_rule(dic_stack,
                    # dic_child, rule) was commented out here and is not
                    # performed.

                    # work on a private copy of the parent graph
                    child_graph = deepcopy(dic_stack["g"])

                    if not is_graph_like(child_graph):
                        to_graph_like(child_graph)

                    # inspect.getmembers() returns (name, value) pairs sorted
                    # by name; the value of the last member is indexed with
                    # "match".
                    # NOTE(review): presumably every rule is a
                    # functools.partial and this is its `keywords` dict --
                    # confirm against where l_zx_rules is built.
                    match = inspect.getmembers(rule)[-1][-1]["match"]
                    if match is not None:
                        all_matches = match(child_graph, self.vertices_function)
                        if len(all_matches) == 0:
                            # rule not applicable: this node is a leaf
                            if self.b_trace_leafs:
                                self.n_leafs += 1
                            self._explore_node(
                                deepcopy(child_graph), node, current_depth
                            )

                        for m in all_matches:
                            # each match is applied to its own copy
                            g = deepcopy(child_graph)
                            rule(
                                g=g,
                                matches_to_select=m,
                            )

                            dic_child["g"] = deepcopy(g)
                            # only deepen to next level if we changed the
                            # graph at the current max level
                            if dic_child["cntr_max_depth"] == max_level:
                                b_iterative_deepening = True
                                # NOTE(review): this explores the graph
                                # *before* the match was applied
                                # (child_graph, not g) -- kept as in the
                                # original; confirm this is intended.
                                self._explore_node(
                                    deepcopy(child_graph), node, current_depth
                                )
                                if self.b_trace_leafs:
                                    self.n_leafs += 1
                            else:
                                # push to stack so the child gets expanded
                                s.put(deepcopy(dic_child))
                    else:
                        g = deepcopy(child_graph)
                        cntr: int = rule(
                            g=g,
                        )

                        # only deepen to next level if we changed the graph
                        # at the current max level
                        if (dic_child["cntr_max_depth"] == max_level) and (cntr > 0):
                            b_iterative_deepening = True

                        # push to stack if rule can still be applied
                        elif cntr > 0:
                            # Push the rewritten graph `g`.  Pushing the
                            # untouched `child_graph` would make the search
                            # re-expand the same graph at every level and
                            # never make progress.
                            dic_child["g"] = g
                            s.put(deepcopy(dic_child))

                        else:
                            # nothing was rewritten: this node is a leaf
                            if self.b_trace_leafs:
                                self.n_leafs += 1
                            self._explore_node(
                                deepcopy(child_graph), node, current_depth
                            )

                    # NOTE: circuit pruning via
                    # self._check_pruning_conditions_circuit(...) and the
                    # per-node exploration guarded by self.b_check_every_node
                    # were commented out here and are not performed.

            # update if solution globally optimal or not
            if self.b_max_duration:
                if self.end_time < self.current_time:
                    self.AllResults.optimal = False
                else:
                    if not self.b_kill_optimal:
                        self.AllResults.optimal = True
            else:
                if not self.b_kill_optimal:
                    self.AllResults.optimal = True
                else:
                    self.AllResults.optimal = False

            # terminate dfs if maximum allowed runtime was reached
            if self.b_max_duration:
                self.current_time: datetime = datetime.now()
                if self.current_time > self.end_time:
                    self.AllResults = TMPResults
                    self.tree = TMPTree
                    break

            # only deepen if tree is not finished
            if b_iterative_deepening:
                max_level += 1

        # update and save results from dfs
        self._save_results()
