"""
PYZX
"""

from copy import deepcopy
import qiskit as qk
from qiskit import QuantumCircuit
from qiskit.transpiler import PassManager
from qiskit.transpiler.passes import Unroll3qOrMore

import pyzx as zx
from pyzx.graph.base import BaseGraph, VT, ET
from pyzx.circuit import Circuit
from pyzx.simplify import to_graph_like, is_graph_like

from pyzx.simplify import (
    id_simp,
    spider_simp,
    pivot_simp,
    lcomp_simp,
    pivot_gadget_simp,
    pivot_boundary_simp,
    gadget_simp,
    supplementarity_simp,
    Stats,
    Simplifier,
    to_gh,
)

from pyzx.utils import VertexType

from typing import Tuple, Optional, Callable, Union

from dataclasses import dataclass, fields, make_dataclass
import dill
import datetime as DT


@dataclass
class Result:
    """Record of one optimization run, filed under a single metric name.

    ``_generate_result_class`` creates one of these per requested metric and
    fills it with the metric name, a copy of the graph, the extracted PyZX
    circuit, the Qiskit circuit and the measured duration in seconds.
    """

    # name of the metric this record belongs to
    metric: str
    # ZX graph associated with the run
    graph: BaseGraph[VT, ET]
    # PyZX circuit extracted from the graph
    circuit: Circuit
    # Qiskit version of the extracted circuit
    qc: QuantumCircuit
    # wall-clock duration in seconds, or None when not measured
    expired_time: Optional[float]


@dataclass
class AllResults:
    """Container type used to annotate what the ``pyzx_*`` entry points return.

    Every field is optional and defaults to ``None``.
    """

    # result record for a metric
    metric: Optional[Result] = None
    # Qiskit circuit before optimization
    qc_init: Optional[QuantumCircuit] = None
    # PyZX circuit before optimization
    c_init: Optional[Circuit] = None
    # ZX graph before optimization
    g_init: Optional[BaseGraph[VT, ET]] = None


"""
State-of-the-art simplification routines of PyZX,
slightly modified to stop once a time limit is reached.
"""


def interior_clifford_simp(
    g: BaseGraph[VT, ET],
    quiet: bool = False,
    stats: Optional[Stats] = None,
    end_time: Optional[DT.datetime] = None,
) -> int:
    """Keeps doing the simplifications ``id_simp``, ``spider_simp``,
    ``pivot_simp`` and ``lcomp_simp`` until none of them can be applied anymore
    or the deadline has passed.

    Args:
        g: Graph to simplify in place.
        quiet: Forwarded to every rewrite routine.
        stats: Optional statistics collector forwarded to every rewrite routine.
        end_time: Point in time after which no further rewrite is started.
            ``None`` (the default) means there is no deadline.

    Returns:
        The number of completed rounds in which at least one rewrite applied.
        ``0`` is returned when the deadline passes during the initial
        ``spider_simp`` / ``to_gh`` phase.
    """
    # A default of ``DT.datetime.now()`` in the signature would be evaluated
    # once at import time, so every call relying on it would already be past
    # its deadline. ``None`` is resolved here, per call, to "never expires".
    deadline = DT.datetime.max if end_time is None else end_time

    spider_simp(g, quiet=quiet, stats=stats)
    if DT.datetime.now() > deadline:
        return 0
    to_gh(g)
    if DT.datetime.now() > deadline:
        return 0

    rounds = 0
    while True:
        applied = 0
        for rewrite in (id_simp, spider_simp, pivot_simp, lcomp_simp):
            applied += rewrite(g, quiet=quiet, stats=stats)
            # The deadline is checked after every rewrite; an interrupted
            # round is not counted.
            if DT.datetime.now() > deadline:
                return rounds
        if applied == 0:
            return rounds
        rounds += 1


def clifford_simp(
    g: BaseGraph[VT, ET],
    quiet: bool = True,
    stats: Optional[Stats] = None,
    end_time: Optional[DT.datetime] = None,
) -> int:
    """Keeps doing rounds of :func:`interior_clifford_simp` and
    :func:`pivot_boundary_simp` until they can't be applied anymore or the
    deadline has passed.

    Args:
        g: Graph to simplify in place.
        quiet: Forwarded to the rewrite routines.
        stats: Optional statistics collector forwarded to the rewrite routines.
        end_time: Point in time after which no further rewrite is started.
            ``None`` (the default) means there is no deadline.

    Returns:
        The sum of the round counts reported by :func:`interior_clifford_simp`.
    """
    # ``DT.datetime.now()`` as a signature default is frozen at import time and
    # therefore always already expired; resolve the "no deadline" case per call.
    # A concrete datetime is always passed down to the interior routine.
    deadline = DT.datetime.max if end_time is None else end_time

    i = 0
    while True:
        i += interior_clifford_simp(g, quiet=quiet, stats=stats, end_time=deadline)
        if DT.datetime.now() > deadline:
            break
        boundary_pivots = pivot_boundary_simp(g, quiet=quiet, stats=stats)
        if DT.datetime.now() > deadline:
            break

        # Without a boundary pivot another interior pass is not attempted.
        if boundary_pivots == 0:
            break
    return i


def reduce_scalar(
    g: BaseGraph[VT, ET],
    quiet: bool = True,
    stats: Optional[Stats] = None,
    max_duration: int = 60 * 60 * 6,
) -> int:
    """Modification of ``full_reduce`` that is tailored for scalar ZX-diagrams.
    It skips the boundary pivots, and it additionally does ``supplementarity_simp``.

    The rewrites are organised in three stages that are tried in order. As soon
    as a stage applies at least one rewrite, the round is counted and the
    procedure starts over from the first stage. The time budget of
    ``max_duration`` seconds is checked before every single rewrite call.

    Returns the number of counted rounds.
    """
    deadline = DT.datetime.now() + DT.timedelta(seconds=max_duration)
    stages = (
        (id_simp, spider_simp, pivot_simp, lcomp_simp),
        (pivot_gadget_simp, gadget_simp),
        (supplementarity_simp,),
    )

    rounds = 0
    while True:
        progressed = False
        for stage in stages:
            applied = 0
            for rewrite in stage:
                # out of time: stop without counting the unfinished round
                if DT.datetime.now() > deadline:
                    return rounds
                applied += rewrite(g, quiet=quiet, stats=stats)
            if applied:
                progressed = True
                break
        if not progressed:
            # no stage could apply anything anymore
            return rounds
        rounds += 1


def full_reduce(
    g: BaseGraph[VT, ET],
    quiet: bool = True,
    stats: Optional[Stats] = None,
    max_duration: int = 60 * 60 * 6,
) -> None:
    """The main simplification routine of PyZX. It uses a combination of :func:`clifford_simp` and
    the gadgetization strategies :func:`pivot_gadget_simp` and :func:`gadget_simp`.

    The graph is modified in place. Simplification stops as soon as
    ``max_duration`` seconds have elapsed; the budget is checked after every
    individual step. Raises ``ValueError`` when the graph contains an H-box.
    """
    for vertex in g.vertices():
        if g.types()[vertex] == VertexType.H_BOX:
            raise ValueError(
                "Input graph is not a ZX-diagram as it contains an H-box. "
                "Maybe call pyzx.hsimplify.from_hypergraph_form(g) first?"
            )

    deadline = DT.datetime.now() + DT.timedelta(seconds=max_duration)

    def out_of_time() -> bool:
        return DT.datetime.now() > deadline

    interior_clifford_simp(g, quiet=quiet, stats=stats, end_time=deadline)
    if out_of_time():
        return

    pivot_gadget_simp(g, quiet=quiet, stats=stats)
    if out_of_time():
        return

    keep_going = True
    while keep_going:
        clifford_simp(g, quiet=quiet, stats=stats, end_time=deadline)
        if out_of_time():
            return

        fused = gadget_simp(g, quiet=quiet, stats=stats)
        if out_of_time():
            return

        interior_clifford_simp(g, quiet=quiet, stats=stats, end_time=deadline)
        if out_of_time():
            return

        pivoted = pivot_gadget_simp(g, quiet=quiet, stats=stats)
        if out_of_time():
            return

        # another pass only pays off if one of the gadget steps did something
        keep_going = fused + pivoted != 0


def teleport_reduce(
    g: BaseGraph[VT, ET],
    quiet: bool = True,
    stats: Optional[Stats] = None,
    max_duration: int = 60 * 60 * 6,
) -> BaseGraph[VT, ET]:
    """This simplification procedure runs :func:`full_reduce` in a way
    that does not change the graph structure of the resulting diagram.
    The only thing that is different in the output graph are the location and value of the phases.

    Args:
        g: Graph handed to PyZX's ``Simplifier``.
        quiet: Forwarded to :func:`full_reduce`.
        stats: Optional statistics collector forwarded to :func:`full_reduce`.
        max_duration: Time budget in seconds forwarded to :func:`full_reduce`.

    Returns:
        The simplifier's ``mastergraph``.
    """
    s = Simplifier(g)
    full_reduce(s.simplifygraph, quiet=quiet, stats=stats, max_duration=max_duration)
    # ``simplifygraph`` is the working copy that full_reduce rewrites, so
    # returning it would hand back a structurally reduced diagram.
    # NOTE(review): this follows upstream pyzx.simplify.teleport_reduce, which
    # returns ``mastergraph`` (structure kept, phases updated) -- confirm
    # against the installed PyZX version.
    return s.mastergraph


def pyzx_full_reduce(
    qc: QuantumCircuit,
    circ_name: str | None,
    quiet: bool = True,
    metric: list[str] | None = None,
    b_dump_results: bool = True,
    func_opt: str = "pyzx_full_reduce",
    max_duration: int = 60 * 60 * 6,
) -> AllResults:
    """Optimize ``qc`` with the time-limited :func:`full_reduce`.

    The circuit is exported to OpenQASM 2, loaded into PyZX, converted to a
    graph-like ZX graph, simplified, extracted back into a circuit and finally
    unrolled into gates acting on at most two qubits.

    Args:
        qc: Circuit to optimize.
        circ_name: Name used in the dump file name.
        quiet: Forwarded to :func:`full_reduce`.
        metric: Metric names under which the result is stored. Defaults to
            ``["c_tcount", "c_gates"]``.
        b_dump_results: Whether to pickle the results to disk.
        func_opt: Label of this optimization used in the dump file name.
        max_duration: Time budget in seconds for the simplification.

    Returns:
        The results container built by ``_generate_result_class``. The recorded
        duration covers simplification, circuit extraction and unrolling.
    """
    # A list literal in the signature would be a single object shared by all
    # calls; create the default per call instead.
    if metric is None:
        metric = ["c_tcount", "c_gates"]

    c: Circuit = zx.Circuit.from_qasm(qk.qasm2.dumps(qc))
    c_init = deepcopy(c)
    qc_init: QuantumCircuit = qk.QuantumCircuit.from_qasm_str(c_init.to_qasm())

    g: BaseGraph[VT, ET] = c.to_basic_gates().to_graph()
    to_graph_like(g)
    # snapshot before the in-place simplification
    g_init: BaseGraph[VT, ET] = deepcopy(g)

    start_time: DT.datetime = DT.datetime.now()
    full_reduce(g, quiet=quiet, max_duration=max_duration)

    c_opt: Circuit = zx.extract_circuit(g.copy(), optimize_czs=True, optimize_cnots=3)
    # unroll qc into one and two qubit gates
    qc_opt: QuantumCircuit = _unroll_qc_2q(
        qk.QuantumCircuit.from_qasm_str(c_opt.to_qasm())
    )
    end_time: DT.datetime = DT.datetime.now()

    # wrap results in data class
    all_results: AllResults = _generate_result_class(
        g_init,
        c_init,
        qc_init,
        g,
        c_opt,
        qc_opt,
        metric,
        (end_time - start_time).total_seconds(),
    )

    # dump pickled results object
    if b_dump_results:
        _dump_results(all_results, circ_name, func_opt)

    return all_results


"""
Full reduce that does not change the graph structure.
"""


def pyzx_teleport_reduce(
    qc: QuantumCircuit,
    circ_name: str | None,
    quiet: bool = True,
    metric: list[str] | None = None,
    b_dump_results: bool = True,
    func_opt: str = "pyzx_teleport_reduce",
    max_duration: int = 60 * 60 * 6,
) -> AllResults:
    """Optimize ``qc`` with the time-limited :func:`teleport_reduce`.

    The circuit is exported to OpenQASM 2, loaded into PyZX, converted to a
    graph-like ZX graph, simplified, extracted back into a circuit and finally
    unrolled into gates acting on at most two qubits.

    Args:
        qc: Circuit to optimize.
        circ_name: Name used in the dump file name.
        quiet: Forwarded to :func:`teleport_reduce`.
        metric: Metric names under which the result is stored. Defaults to
            ``["c_tcount", "c_gates"]``.
        b_dump_results: Whether to pickle the results to disk.
        func_opt: Label of this optimization used in the dump file name.
        max_duration: Time budget in seconds for the simplification.

    Returns:
        The results container built by ``_generate_result_class``. The recorded
        duration covers simplification, circuit extraction and unrolling.
    """
    # A list literal in the signature would be a single object shared by all
    # calls; create the default per call instead.
    if metric is None:
        metric = ["c_tcount", "c_gates"]

    c: Circuit = zx.Circuit.from_qasm(qk.qasm2.dumps(qc))
    c_init = deepcopy(c)
    qc_init: QuantumCircuit = qk.QuantumCircuit.from_qasm_str(c_init.to_qasm())

    g: BaseGraph[VT, ET] = c.to_basic_gates().to_graph()
    to_graph_like(g)
    g_init: BaseGraph[VT, ET] = deepcopy(g)

    start_time: DT.datetime = DT.datetime.now()
    # teleport_reduce works on copies held by the Simplifier and hands the
    # outcome back as its return value; it must be rebound here, otherwise the
    # circuit below would be extracted from the unoptimized graph.
    g = teleport_reduce(g, quiet=quiet, max_duration=max_duration)

    c_opt: Circuit = zx.extract_circuit(g.copy(), optimize_czs=True, optimize_cnots=3)
    # unroll qc into one and two qubit gates
    qc_opt: QuantumCircuit = _unroll_qc_2q(
        qk.QuantumCircuit.from_qasm_str(c_opt.to_qasm())
    )
    end_time: DT.datetime = DT.datetime.now()

    # wrap results in data class
    all_results: AllResults = _generate_result_class(
        g_init,
        c_init,
        qc_init,
        g,
        c_opt,
        qc_opt,
        metric,
        (end_time - start_time).total_seconds(),
    )

    # dump pickled results object
    if b_dump_results:
        _dump_results(all_results, circ_name, func_opt)

    return all_results


"""
Full reduce that skips boundary pivots and adds the supplementarity simplification.
"""


def pyzx_scalar_reduce(
    qc: QuantumCircuit,
    circ_name: str | None,
    quiet: bool = True,
    metric: list[str] | None = None,
    b_dump_results: bool = True,
    func_opt: str = "scalar_reduce",
    max_duration: int = 60 * 60 * 6,
) -> AllResults:
    """Optimize ``qc`` with the time-limited :func:`reduce_scalar`.

    The circuit is exported to OpenQASM 2, loaded into PyZX, converted to a
    graph-like ZX graph, simplified, extracted back into a circuit and finally
    unrolled into gates acting on at most two qubits.

    Args:
        qc: Circuit to optimize.
        circ_name: Name used in the dump file name.
        quiet: Forwarded to :func:`reduce_scalar`.
        metric: Metric names under which the result is stored. Defaults to
            ``["c_tcount", "c_gates"]``.
        b_dump_results: Whether to pickle the results to disk.
        func_opt: Label of this optimization used in the dump file name.
        max_duration: Time budget in seconds for the simplification.

    Returns:
        The results container built by ``_generate_result_class``. The recorded
        duration covers simplification, circuit extraction and unrolling.
    """
    # A list literal in the signature would be a single object shared by all
    # calls; create the default per call instead.
    if metric is None:
        metric = ["c_tcount", "c_gates"]

    c: Circuit = zx.Circuit.from_qasm(qk.qasm2.dumps(qc))
    c_init = deepcopy(c)
    qc_init: QuantumCircuit = qk.QuantumCircuit.from_qasm_str(c_init.to_qasm())

    g: BaseGraph[VT, ET] = c.to_basic_gates().to_graph()
    to_graph_like(g)
    # snapshot before the in-place simplification
    g_init: BaseGraph[VT, ET] = deepcopy(g)

    start_time: DT.datetime = DT.datetime.now()
    reduce_scalar(g, quiet=quiet, max_duration=max_duration)

    c_opt: Circuit = zx.extract_circuit(g.copy(), optimize_czs=True, optimize_cnots=3)
    # unroll qc into one and two qubit gates
    qc_opt: QuantumCircuit = _unroll_qc_2q(
        qk.QuantumCircuit.from_qasm_str(c_opt.to_qasm())
    )
    end_time: DT.datetime = DT.datetime.now()

    # wrap results in data class
    all_results: AllResults = _generate_result_class(
        g_init,
        c_init,
        qc_init,
        g,
        c_opt,
        qc_opt,
        metric,
        (end_time - start_time).total_seconds(),
    )

    # dump pickled results object
    if b_dump_results:
        _dump_results(all_results, circ_name, func_opt)

    return all_results


"""
get statistics of the unoptimized circuit that was converted to and from pyzx
"""


def unoptimized(
    qc: QuantumCircuit,
    circ_name: str | None,
    quiet: bool = True,
    metric: list[str] | None = None,
    b_dump_results: bool = True,
    func_opt: str = "unoptimized",
    max_duration: int = 60 * 60 * 6,
) -> AllResults:
    """Produce baseline results for ``qc`` without running any simplification.

    The circuit goes through the same conversion as in the optimizing entry
    points (OpenQASM 2 -> PyZX circuit -> graph-like ZX graph -> extracted
    circuit -> unrolled Qiskit circuit), but no reduce routine is applied.

    Args:
        qc: Circuit to convert.
        circ_name: Name used in the dump file name.
        quiet: Not used by this function; present so that the signature
            matches the optimizing entry points.
        metric: Metric names under which the result is stored. Defaults to
            ``["c_tcount", "c_gates"]``.
        b_dump_results: Whether to pickle the results to disk.
        func_opt: Label used in the dump file name.
        max_duration: Not used by this function; present so that the signature
            matches the optimizing entry points.

    Returns:
        The results container built by ``_generate_result_class``. The recorded
        duration covers circuit extraction and unrolling only.
    """
    # A list literal in the signature would be a single object shared by all
    # calls; create the default per call instead.
    if metric is None:
        metric = ["c_tcount", "c_gates"]

    c: Circuit = zx.Circuit.from_qasm(qk.qasm2.dumps(qc))
    c_init = deepcopy(c)
    qc_init: QuantumCircuit = qk.QuantumCircuit.from_qasm_str(c_init.to_qasm())

    g: BaseGraph[VT, ET] = c.to_basic_gates().to_graph()
    to_graph_like(g)
    g_init: BaseGraph[VT, ET] = deepcopy(g)

    start_time: DT.datetime = DT.datetime.now()

    c_opt: Circuit = zx.extract_circuit(g.copy(), optimize_czs=True, optimize_cnots=3)
    # unroll qc into one and two qubit gates
    qc_opt: QuantumCircuit = _unroll_qc_2q(
        qk.QuantumCircuit.from_qasm_str(c_opt.to_qasm())
    )
    end_time: DT.datetime = DT.datetime.now()

    # wrap results in data class
    all_results: AllResults = _generate_result_class(
        g_init,
        c_init,
        qc_init,
        g,
        c_opt,
        qc_opt,
        metric,
        (end_time - start_time).total_seconds(),
    )

    # dump pickled results object
    if b_dump_results:
        _dump_results(all_results, circ_name, func_opt)

    return all_results


"""
HELPER FUNCTIONS
"""

"""
generate dataclass that wraps the resulting graphs and results
"""


def _generate_result_class(
    g_init,
    c_init,
    qc_init,
    g: BaseGraph[VT, ET],
    c_opt: Circuit,
    qc_opt: QuantumCircuit,
    metric: list,
    duration: float,
) -> AllResults:
    """Bundle the initial and optimized artefacts of one run into a dataclass.

    A dataclass named ``AllResults`` is created on the fly with one field per
    entry of ``metric`` (named after ``str(entry)``) followed by ``g_init``,
    ``c_init`` and ``qc_init``, and an instance of it is returned.

    Args:
        g_init: Graph before optimization; stored as a deep copy.
        c_init: PyZX circuit before optimization; stored as a deep copy.
        qc_init: Qiskit circuit before optimization; stored as a deep copy.
        g: Graph after optimization; every metric entry gets its own copy.
        c_opt: Circuit extracted from ``g``.
        qc_opt: Qiskit version of ``c_opt``.
        metric: Metric names; each becomes a field holding a ``Result``.
        duration: Measured duration in seconds, stored in every ``Result``.

    Returns:
        The populated results object. The initial artefacts are reachable as
        ``.g_init`` / ``.c_init`` / ``.qc_init`` and the per-metric records as
        ``.<metric name>``.
    """
    metric_names = [str(m) for m in metric]

    # Use a local name so the module-level ``AllResults`` is not shadowed, and
    # declare the metric fields as ``Result`` (``type(Result)`` is just ``type``).
    results_cls = make_dataclass(
        "AllResults",
        [(name, Result) for name in metric_names]
        + [
            ("g_init", BaseGraph[VT, ET]),
            ("c_init", Circuit),
            ("qc_init", QuantumCircuit),
        ],
    )

    # One record per metric. Previously a nested loop wrote a Result into every
    # field, which also replaced g_init / c_init / qc_init.
    per_metric = {
        name: Result(name, g.copy(), c_opt, qc_opt, duration) for name in metric_names
    }

    # Return an instance so the values live on the object instead of being
    # attached to the generated class itself.
    return results_cls(
        **per_metric,
        g_init=deepcopy(g_init),
        c_init=deepcopy(c_init),
        qc_init=deepcopy(qc_init),
    )


"""
Dump our objects in serialized form so we can work with the graphs and circuits further.
"""


def _dump_results(
    all_results: AllResults, circ_name: None | str, func_opt: str
) -> None:
    """Pickle ``all_results`` with ``dill`` into the current working directory.

    The file is named ``_<timestamp>_<circ_name>_<func_opt>.pkl`` where the
    timestamp has the form ``YYYY-mm-dd-HH-MM-SS``. A ``circ_name`` of ``None``
    ends up literally as ``None`` in the file name.

    Args:
        all_results: Object to serialize.
        circ_name: Circuit name used in the file name.
        func_opt: Optimization label used in the file name.
    """
    # save our results as a pickled object
    time_prefix: str = DT.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    # Both branches of the former isinstance check built the same file name,
    # so a single code path is enough.
    out_path = f"_{time_prefix}_{circ_name}_{func_opt}.pkl"

    # The context manager closes the handle even if pickling fails.
    with open(out_path, "wb") as out_file:
        dill.dump(all_results, out_file, recurse=True)


"""
for comparison reasons, all circuits need to be in the same basis
"""


# ensure that we get comparable results and expand every quantum circuit into single and two qubit gates
def _unroll_qc_2q(qc: QuantumCircuit) -> QuantumCircuit:
    """Run ``qc`` through a pass manager consisting only of ``Unroll3qOrMore``
    and return the resulting circuit."""
    unroller = PassManager([Unroll3qOrMore()])
    return unroller.run(qc)
