import logging

from pyzx.graph.base import BaseGraph, ET, VT
from pyzx.circuit import Circuit
from pyzx import extract_circuit
from pyzx import tcount
from pyzx.simplify import is_graph_like, to_graph_like


class Metric:
    """Base class for integer cost metrics used to compare two ZX-graphs.

    Subclasses override ``_metric_count`` to map a graph to an integer cost.
    ``check`` then reports whether an optimized graph is strictly cheaper
    than the original.
    """

    def __str__(self) -> str:
        return "generic_metric"

    def check(self, g_opt: BaseGraph[VT, ET], g_org: BaseGraph[VT, ET]) -> bool:
        """Return True if ``g_opt`` scores strictly lower than ``g_org``.

        Evaluation is best-effort. If computing either count raises an
        ``Exception`` (for example, a subclass whose circuit extraction
        fails), the error is logged at debug level and False is returned,
        so the optimized graph is never reported as better.
        ``BaseException`` subclasses such as ``KeyboardInterrupt`` are not
        caught.

        Args:
            g_opt: candidate (optimized) graph.
            g_org: reference (original) graph.

        Returns:
            True iff ``_metric_count(g_opt) < _metric_count(g_org)``.
        """
        try:
            org: int = self._metric_count(g_org)
            opt: int = self._metric_count(g_opt)
            return opt < org
        except Exception:
            # Deliberate best-effort: a metric that cannot be evaluated never
            # claims an improvement. Log it so failures are not invisible.
            logging.getLogger(__name__).debug(
                "metric %s could not be evaluated", self, exc_info=True
            )
            return False

    def _metric_count(self, g: BaseGraph) -> int:
        """Return the cost of ``g``. The base implementation always returns 0."""
        return 0


class MetricTcount(Metric):
    """Graph-level metric: number of vertex phases with denominator > 2."""

    def __str__(self) -> str:
        return "g_tcount"

    def _metric_count(self, g: BaseGraph[VT, ET]) -> int:
        """Count the phases in ``g.phases()`` whose denominator exceeds 2.

        NOTE(review): assumes phases are Fraction-like multiples of pi, so a
        denominator > 2 marks a non-Clifford (e.g. T) phase. Symbolic phases
        without ``.denominator`` would raise AttributeError here. Confirm.
        """
        count = 0
        for phase in g.phases().values():
            if phase.denominator > 2:
                count += 1
        return count


class MetricTcountCircuit(Metric):
    """Circuit-level metric: T-count of the circuit extracted from a graph."""

    def __str__(self) -> str:
        return "c_tcount"

    def _metric_count(self, g: BaseGraph[VT, ET]) -> int:
        """Extract a basic-gate circuit from a copy of ``g`` and return its T-count.

        ``g`` itself is left unmodified.
        """
        # to_graph_like is applied for its in-place effect (its result is not
        # used), and extract_circuit consumes its argument. Work on a private
        # copy so evaluating the metric does not rewrite the caller's graph.
        work = g.copy()
        # required because otherwise circuit extraction might fail
        if not is_graph_like(work):
            to_graph_like(work)
        c: Circuit = extract_circuit(
            work, optimize_czs=True, optimize_cnots=3
        ).to_basic_gates()
        return tcount(c)


class MetricTwoQubit(Metric):
    """Circuit-level metric: number of CNOT and CZ gates after extraction."""

    def __str__(self) -> str:
        return "c_two_q"

    def _metric_count(self, g: BaseGraph[VT, ET]) -> int:
        """Extract a basic-gate circuit from a copy of ``g`` and count CNOT/CZ gates.

        ``g`` itself is left unmodified.
        """
        # Both to_graph_like and extract_circuit rewrite their argument in
        # place. Previously the caller's graph was converted and then consumed
        # by extraction, so operate on a private copy instead.
        work = g.copy()
        # required because otherwise circuit extraction might fail
        if not is_graph_like(work):
            to_graph_like(work)
        c: Circuit = extract_circuit(work).to_basic_gates()
        # Use a distinct loop name so the graph parameter is not shadowed.
        return sum(1 for gate in c.gates if gate.name in ("CNOT", "CZ"))


class MetricEdge(Metric):
    """Graph-level metric: total number of edges in the graph."""

    def __str__(self) -> str:
        return "g_edges"

    def _metric_count(self, g: BaseGraph[VT, ET]) -> int:
        """Return the edge count reported by ``g.num_edges()``."""
        edge_total = g.num_edges()
        return edge_total


class MetricVertex(Metric):
    """Graph-level metric: total number of vertices in the graph."""

    def __str__(self) -> str:
        return "g_vertices"

    def _metric_count(self, g: BaseGraph[VT, ET]) -> int:
        """Return the vertex count reported by ``g.num_vertices()``."""
        vertex_total = g.num_vertices()
        return vertex_total
