"""
DFS
"""

import qiskit as qk
from qiskit import QuantumCircuit

import pyzx as zx
from pyzx.graph.base import BaseGraph, VT, ET
from pyzx.circuit import Circuit
from pyzx import extract_circuit

from typing import Tuple, Callable, Any, Self

import queue
from queue import LifoQueue

from copy import deepcopy
from copy import copy

from multiprocessing import Process
from multiprocessing import cpu_count
from multiprocessing.managers import BaseManager

from full_analysis import full_analysis

from .dfs import get_circuit_statistics, zx_rules


def dfs(
    qc: QuantumCircuit,
    quiet: bool = True,
    metric: str = "c_tcount",
    b_count_leafs: bool = False,
) -> (
    Tuple[QuantumCircuit, Circuit, BaseGraph[VT, ET]]
    | Tuple[QuantumCircuit, Circuit, BaseGraph[VT, ET], int]
):
    """Apply ZX rewrite rules depth-first and keep the best extracted circuit.

    Starting from the ZX graph of ``qc``, every rule returned by ``zx_rules()``
    is applied to a copy of the current graph. After each application the graph
    is extracted back to a circuit and compared with the best circuit so far via
    ``check_metric``. A child graph is explored further only if its rule
    reported at least one rewrite (``cntr > 0``).

    The colour-change rules ``to_gh`` / ``to_rg`` are rate-limited: every path
    from the root may use at most ``limit_gh_rg(qc, mult=2)`` of them.

    Args:
        qc: Circuit to optimise.
        quiet: Forwarded to every rule call.
        metric: Key of the statistics dict produced by
            ``get_circuit_statistics``; lower values are better.
        b_count_leafs: If True, additionally return the number of leaves, i.e.
            rule applications that reported no rewrite.

    Returns:
        ``(qc_best, c_best, g_best)``, or ``(qc_best, c_best, g_best, n_leafs)``
        when ``b_count_leafs`` is set. Without any improvement the input
        circuit, its pyzx ``Circuit`` and its graph are returned.
    """
    colour_rules = ("to_gh", "to_rg")

    # import rules from pyzx and our custom rules
    l_rules: list[Callable] = zx_rules()

    c: Circuit = Circuit.from_qasm(qk.qasm2.dumps(qc))
    g: BaseGraph[VT, ET] = c.to_graph()

    qc_best: QuantumCircuit = qc
    c_best: Circuit = c
    g_best: BaseGraph[VT, ET] = g
    cntr_leafs: int = 0

    # Stack entries are (graph, remaining colour-change budget on that path).
    # A plain list is a LIFO stack; no thread-safe queue is needed here.
    stack: list[Tuple[BaseGraph[VT, ET], int]] = [
        (deepcopy(g), limit_gh_rg(qc, mult=2))
    ]

    while stack:
        stack_graph, budget = stack.pop()

        for rule in l_rules:
            # The budget is per path: spending it for one child must not
            # reduce what its sibling rules inherit from the parent.
            child_budget = budget
            if rule.__name__ in colour_rules:
                if budget <= 0:
                    continue
                child_budget = budget - 1

            child_graph: BaseGraph[VT, ET] = stack_graph.copy()
            cntr: int = rule(child_graph, quiet=quiet)

            # Extraction (or statistics) may fail for some intermediate graphs;
            # such graphs are simply not candidates but can still be explored.
            try:
                c_opt: Circuit = extract_circuit(child_graph.copy())
                qc_opt: QuantumCircuit = QuantumCircuit.from_qasm_str(c_opt.to_qasm())

                if check_metric(
                    qc_opt,
                    c_opt,
                    child_graph,
                    qc_best,
                    c_best,
                    g_best,
                    metric=metric,
                ):
                    qc_best = qc_opt
                    c_best = c_opt
                    g_best = child_graph
            except Exception:
                pass

            if cntr > 0:
                stack.append((child_graph, child_budget))
            else:
                cntr_leafs += 1

    if b_count_leafs:
        return qc_best, c_best, g_best, cntr_leafs
    return qc_best, c_best, g_best


class _LifoManager(BaseManager):
    """Manager that serves a shared ``queue.LifoQueue`` to worker processes.

    Registering on this private subclass, instead of on ``BaseManager`` itself,
    avoids mutating the registry of every ``BaseManager`` in the process. It is
    defined at module level so it can be referenced by name when pickled.
    """


_LifoManager.register("LifoQueue", LifoQueue)


def parallel_dfs(qc: QuantumCircuit, quiet: bool = True):
    """Run the DFS expansion in parallel on ``cpu_count()`` worker processes.

    A LIFO task queue served by a manager process is seeded with the root
    ``visit_node`` task. Workers pull ``(function, *args)`` tasks from it, and
    ``visit_node`` pushes further tasks for children that were rewritten. Once
    every task has been marked done, one ``None`` sentinel per worker stops them.

    NOTE(review): work in progress. Results are not collected and the function
    returns ``None``. ``quiet`` is currently unused, and the colour-change budget
    is hard-coded to 2 instead of using ``limit_gh_rg``.
    """
    c: Circuit = Circuit.from_qasm(qk.qasm2.dumps(qc))
    g: BaseGraph[VT, ET] = c.to_graph()

    n_cntr_colour: int = 2  # limit_gh_rg(qc, mult=2)

    # The context manager starts the manager process and shuts it down on exit,
    # so it does not leak after the search finishes (or fails).
    with _LifoManager() as manager:
        # Proxy to a LifoQueue living in the manager process.
        s = manager.LifoQueue()

        workers: list[Worker] = []
        for _ in range(cpu_count()):
            worker: Worker = Worker(s)
            worker.start()
            workers.append(worker)

        s.put((visit_node, g, n_cntr_colour, s))
        # Blocks until task_done() was called for every task ever put.
        s.join()

        for _ in workers:
            s.put(None)

        for worker in workers:
            worker.join()


def visit_node(
    stack_graph: BaseGraph[VT, ET],
    cntr_colour_change: int,
    s: LifoQueue[Tuple[Callable, BaseGraph[VT, ET], int, LifoQueue[Any]] | None],
):
    """Expand one DFS node by applying every ZX rule to a copy of ``stack_graph``.

    Children whose rule reported at least one rewrite are pushed back onto the
    shared task queue ``s`` as new ``visit_node`` tasks. Colour-change rules
    (``to_gh`` / ``to_rg``) are skipped once the path's budget
    ``cntr_colour_change`` is exhausted. Otherwise they hand a budget reduced by
    one to that child only.

    Each rule is applied exactly once per child.
    """
    l_rules: list[Callable] = zx_rules()
    quiet: bool = True

    for rule in l_rules:
        # Per-child budget: siblings keep the parent's budget.
        child_budget = cntr_colour_change
        if rule.__name__ in ("to_gh", "to_rg"):
            if cntr_colour_change <= 0:
                continue
            child_budget = cntr_colour_change - 1

        child_graph: BaseGraph[VT, ET] = stack_graph.copy()
        cntr: int = rule(child_graph, quiet=quiet)

        if cntr > 0:
            s.put((visit_node, child_graph, child_budget, s))
        else:
            # Leaf: try to extract a circuit.
            # NOTE(review): the extracted circuit is currently discarded; no
            # result is recorded anywhere yet.
            try:
                c_opt: Circuit = extract_circuit(child_graph.copy())
                QuantumCircuit.from_qasm_str(c_opt.to_qasm())
            except Exception:
                pass


class Worker(Process):
    """Process that executes ``(function, *args)`` tasks taken from a shared queue.

    The worker keeps pulling tasks until it receives the ``None`` sentinel.
    ``task_done()`` is called for every task, including one that raised, so a
    ``join()`` on the queue in the parent cannot deadlock on a failed task.
    """

    def __init__(self, task_queue):
        super().__init__()
        self.task_queue = task_queue

    def run(self):
        import traceback

        for function, *args in iter(self.task_queue.get, None):
            try:
                # Run the provided function with its parameters in this process.
                function(*args)
            except Exception:
                # Report and keep serving: a dead worker would leave the
                # remaining tasks unprocessed.
                traceback.print_exc()
            finally:
                self.task_queue.task_done()  # Notify queue that task is complete


def dfs_get_trace(end: int, dic_hist: dict) -> list[dict | None]:
    path: list[int] = [end]
    step: int = end
    l_path_tmp: list[dict | None] = [dic_hist[end]]

    while path[-1] != 0:
        step = dic_hist[step]["parent"]
        path.append(step)

        if step != 0:
            l_path_tmp.append(dic_hist[step])

    l_path_tmp.reverse()

    return l_path_tmp


def check_metric(
    qc_new: QuantumCircuit,
    c_new: Circuit,
    g_new: BaseGraph[VT, ET],
    qc_old: QuantumCircuit,
    c_old: Circuit,
    g_old: BaseGraph[VT, ET],
    metric: str = "c_tcount",
) -> bool:
    """Tell whether the new candidate strictly beats the old one on ``metric``.

    Both candidates are summarised with ``get_circuit_statistics``, and the
    values stored under ``metric`` are compared. Only a strictly smaller value
    counts as better, so ties favour the old candidate.
    """
    stats_new: dict = get_circuit_statistics(qc_new, c_new, g_new)
    stats_old: dict = get_circuit_statistics(qc_old, c_old, g_old)

    # bool() keeps the return a plain Python bool even for numpy-like scores.
    return bool(stats_new[metric] < stats_old[metric])


def limit_gh_rg(qc: QuantumCircuit, mult: int = 2) -> int:
    """Compute the colour-change (``to_gh`` / ``to_rg``) budget for the DFS.

    ``full_analysis`` is run on the ZX graph of ``qc``. Presumably it returns
    how often the colour was changed — confirm against ``full_analysis``. A
    positive result is multiplied by ``mult``; otherwise ``mult`` is added.
    """
    graph: BaseGraph[VT, ET] = zx.Circuit.from_qasm(qk.qasm2.dumps(qc)).to_graph()
    n_changes: int = full_analysis(graph)

    if n_changes <= 0:
        # No colour changes counted: just use ``mult`` as the limit.
        return n_changes + mult
    return n_changes * mult
