from typing import Optional, Tuple

from pyzx import extract_circuit
from pyzx.circuit import Circuit
from pyzx.graph.base import BaseGraph, VT, ET

from metric import MetricTcount


class PruningCondition:
    """Base class for conditions that decide whether a search branch is pruned.

    Subclasses override :meth:`check`. The base implementation never prunes.
    """

    def check(
        self, dic_child: dict, dic_parent: dict, metric: list
    ) -> Tuple[bool, dict]:
        """Decide whether the child node should be skipped.

        Args:
            dic_child: state dictionary of the candidate child node.
            dic_parent: state dictionary of the child's parent node.
            metric: metric objects available to the condition.

        Returns:
            ``(skip, dic_child)``; ``skip`` is True when the child should be
            pruned. The base implementation always returns ``(False, dic_child)``.
        """
        return False, dic_child

    def _check_metric(
        self,
        g_new: BaseGraph[VT, ET],
        g_old: BaseGraph[VT, ET],
        metric: Optional[list] = None,
    ) -> list[bool]:
        """Report, per metric, whether ``g_new`` is better than ``g_old``.

        Each metric object is re-initialised in place with ``(g_old, g_new)``
        and then asked ``check()``.

        Args:
            g_new: the newly rewritten graph.
            g_old: the graph it was derived from.
            metric: metric objects to evaluate; defaults to a fresh
                ``[MetricTcount()]`` on each call.

        Returns:
            A list with one bool per metric, in order; ``True`` where that
            metric's ``check()`` was truthy.
        """
        if metric is None:
            # Build the default per call. A list default would be created once
            # at definition time, and its single MetricTcount instance would be
            # shared and re-initialised by every call.
            metric = [MetricTcount()]

        l_better: list[bool] = []
        for m in metric:
            # Re-run the metric's constructor so it compares the current pair
            # of graphs (the caller's metric objects are updated in place).
            m.__init__(g_old, g_new)
            l_better.append(bool(m.check()))

        return l_better


class PruneColourCycle(PruningCondition):
    """Prune a child when it and its parent were both produced by a colour rule.

    The colour rules are ``"to_gh"`` and ``"to_rg"``. Applying two of them in a
    row is treated as a cycle and the child is skipped.
    """

    def check(
        self, dic_child: dict, dic_parent: dict, metric: list
    ) -> Tuple[bool, dict]:
        """Return ``(True, dic_child)`` if both past rules are colour rules.

        If either node's ``"past_rule"`` is ``None``, the child is never pruned.
        ``metric`` is unused.
        """
        colour_rules = ("to_gh", "to_rg")

        parent_rule = dic_parent["past_rule"]
        if parent_rule is None:
            return False, dic_child

        child_rule = dic_child["past_rule"]
        if child_rule is None:
            return False, dic_child

        is_cycle = child_rule in colour_rules and parent_rule in colour_rules
        return is_cycle, dic_child


class PruneColourChange(PruningCondition):
    """Limit how many colour rules (``"to_gh"``/``"to_rg"``) a path may use.

    The child inherits the parent's ``"n_colour_change"`` limit and its
    ``"cntr_colour_change"`` counter. Each colour rule increments the counter
    until it reaches the limit, after which such children are pruned.
    """

    def check(
        self, dic_child: dict, dic_parent: dict, metric: list
    ) -> Tuple[bool, dict]:
        """Update the child's colour counters and decide whether to prune it.

        Mutates ``dic_child`` in place (counter copy and, when allowed, the
        increment). ``metric`` is unused.
        """
        dic_child = self._update_colour_count(dic_child, dic_parent)

        if dic_child["past_rule"] not in ("to_gh", "to_rg"):
            return False, dic_child

        if dic_child["cntr_colour_change"] < dic_child["n_colour_change"]:
            dic_child["cntr_colour_change"] += 1
            return False, dic_child

        # The colour-change budget is exhausted.
        return True, dic_child

    def _update_colour_count(self, dic_child: dict, dic_parent: dict) -> dict:
        """Copy the parent's colour-change limit and counter onto the child."""
        for key in ("n_colour_change", "cntr_colour_change"):
            dic_child[key] = dic_parent[key]
        return dic_child


class PruneMaxDepth(PruningCondition):
    """Prune children once the parent's depth counter reaches ``max_depth``."""

    def check(
        self, dic_child: dict, dic_parent: dict, metric: list
    ) -> Tuple[bool, dict]:
        """Skip the child if ``dic_parent["cntr_max_depth"] >= dic_child["max_depth"]``.

        ``metric`` is unused; ``dic_child`` is returned unchanged.
        """
        b_skip_loop = bool(dic_parent["cntr_max_depth"] >= dic_child["max_depth"])
        return b_skip_loop, dic_child

    def _update_depth(self, dic_child: dict, dic_parent: dict) -> dict:
        """Set the child's depth counter to one more than its parent's.

        Also records the largest counter seen so far in ``self.depth``.

        Returns:
            The same ``dic_child`` dict, mutated in place.
        """
        dic_child["cntr_max_depth"] = dic_parent["cntr_max_depth"] + 1

        # Nothing in this class hierarchy initialises ``self.depth``, so the
        # first call used to raise AttributeError; seed it from the first child.
        deepest = getattr(self, "depth", None)
        if deepest is None or deepest < dic_child["cntr_max_depth"]:
            self.depth = dic_child["cntr_max_depth"]

        return dic_child


class PruneOptimal(PruningCondition):
    """Prune a child whose graph is already optimal for some metric.

    A metric is considered optimal when ``m._metric_count(graph) == 0``.
    """

    def check(
        self, dic_child: dict, dic_parent: dict, metric: list
    ) -> Tuple[bool, dict]:
        """Return ``(True, dic_child)`` if any metric counts 0 on the child graph.

        Evaluation stops at the first metric that reports 0, so later (possibly
        expensive) metric counts are not computed. With no metrics the child is
        never pruned.
        """
        b_skip_loop: bool = any(
            m._metric_count(dic_child["g"]) == 0 for m in metric
        )
        return b_skip_loop, dic_child


class PruneNonCircuits(PruningCondition):
    """Prune a child whose graph cannot be extracted into a circuit."""

    def check(
        self, dic_child: dict, dic_parent: dict, metric: list
    ) -> Tuple[bool, dict]:
        """Return ``(True, dic_child)`` if ``extract_circuit`` fails on the child graph.

        Extraction is run on ``dic_child["g"].copy()`` rather than the graph
        itself. ``metric`` is unused.
        """
        # Fetch and copy outside the try so a malformed node dict (e.g. missing
        # "g") surfaces as an error instead of being mistaken for a non-circuit.
        g_copy = dic_child["g"].copy()

        try:
            extract_circuit(g_copy)
        except Exception:
            # NOTE(review): the broad clause is kept on purpose, on the
            # assumption that extract_circuit reports non-extractable graphs
            # with generic exceptions. Confirm before narrowing it.
            return True, dic_child

        return False, dic_child


class PruneWorseCircuits(PruningCondition):
    """Prune a child that improves on its parent in none of the given metrics."""

    def check(
        self, dic_child: dict, dic_parent: dict, metric: list
    ) -> Tuple[bool, dict]:
        """Compare child and parent graphs and skip the child if nothing improved.

        Returns ``(False, dic_child)`` as soon as at least one metric reports an
        improvement; otherwise ``(True, dic_child)``.
        """
        improvements = self._check_metric(
            dic_child["g"], dic_parent["g"], metric=metric
        )
        return not any(improvements), dic_child
