from __future__ import annotations
from collections.abc import Generator

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from matplotlib.axes import Axes
from matplotlib.figure import Figure
from typing import TYPE_CHECKING
if TYPE_CHECKING:  # TODO move to seaborn._core.typing?
    from seaborn._core.plot import FacetSpec, PairSpec
    from matplotlib.figure import SubFigure


class Subplots:
    """
    Interface for creating and using matplotlib subplots based on seaborn parameters.

    Parameters
    ----------
    subplot_spec : dict
        Keyword args for :meth:`matplotlib.figure.Figure.subplots`.
    facet_spec : dict
        Parameters that control subplot faceting.
    pair_spec : dict
        Parameters that control subplot pairing.
    data : PlotData
        Data used to define figure setup.

    """
    def __init__(
        self,
        subplot_spec: dict,  # TODO define as TypedDict
        facet_spec: FacetSpec,
        pair_spec: PairSpec,
    ):

        self.subplot_spec = subplot_spec

        self._check_dimension_uniqueness(facet_spec, pair_spec)
        self._determine_grid_dimensions(facet_spec, pair_spec)
        self._handle_wrapping(facet_spec, pair_spec)
        self._determine_axis_sharing(pair_spec)

    def _check_dimension_uniqueness(
        self, facet_spec: FacetSpec, pair_spec: PairSpec
    ) -> None:
        """Reject specs that pair and facet on (or wrap to) same figure dimension."""
        err = None

        facet_vars = facet_spec.get("variables", {})

        if facet_spec.get("wrap") and {"col", "row"} <= set(facet_vars):
            err = "Cannot wrap facets when specifying both `col` and `row`."
        elif (
            pair_spec.get("wrap")
            and pair_spec.get("cross", True)
            and len(pair_spec.get("structure", {}).get("x", [])) > 1
            and len(pair_spec.get("structure", {}).get("y", [])) > 1
        ):
            err = "Cannot wrap subplots when pairing on both `x` and `y`."

        collisions = {"x": ["columns", "rows"], "y": ["rows", "columns"]}
        for pair_axis, (multi_dim, wrap_dim) in collisions.items():
            if pair_axis not in pair_spec.get("structure", {}):
                continue
            elif multi_dim[:3] in facet_vars:
                err = f"Cannot facet the {multi_dim} while pairing on `{pair_axis}``."
            elif wrap_dim[:3] in facet_vars and facet_spec.get("wrap"):
                err = f"Cannot wrap the {wrap_dim} while pairing on `{pair_axis}``."
            elif wrap_dim[:3] in facet_vars and pair_spec.get("wrap"):
                err = f"Cannot wrap the {multi_dim} while faceting the {wrap_dim}."

        if err is not None:
            raise RuntimeError(err)  # TODO what err class? Define PlotSpecError?

    def _determine_grid_dimensions(
        self, facet_spec: FacetSpec, pair_spec: PairSpec
    ) -> None:
        """Parse faceting and pairing information to define figure structure."""
        self.grid_dimensions: dict[str, list] = {}
        for dim, axis in zip(["col", "row"], ["x", "y"]):

            facet_vars = facet_spec.get("variables", {})
            if dim in facet_vars:
                self.grid_dimensions[dim] = facet_spec["structure"][dim]
            elif axis in pair_spec.get("structure", {}):
                self.grid_dimensions[dim] = [
                    None for _ in pair_spec.get("structure", {})[axis]
                ]
            else:
                self.grid_dimensions[dim] = [None]

            self.subplot_spec[f"n{dim}s"] = len(self.grid_dimensions[dim])

        if not pair_spec.get("cross", True):
            self.subplot_spec["nrows"] = 1

        self.n_subplots = self.subplot_spec["ncols"] * self.subplot_spec["nrows"]

    def _handle_wrapping(
        self, facet_spec: FacetSpec, pair_spec: PairSpec
    ) -> None:
        """Update figure structure parameters based on facet/pair wrapping."""
        self.wrap = wrap = facet_spec.get("wrap") or pair_spec.get("wrap")
        if not wrap:
            return

        wrap_dim = "row" if self.subplot_spec["nrows"] > 1 else "col"
        flow_dim = {"row": "col", "col": "row"}[wrap_dim]
        n_subplots = self.subplot_spec[f"n{wrap_dim}s"]
        flow = int(np.ceil(n_subplots / wrap))

        if wrap < self.subplot_spec[f"n{wrap_dim}s"]:
            self.subplot_spec[f"n{wrap_dim}s"] = wrap
        self.subplot_spec[f"n{flow_dim}s"] = flow
        self.n_subplots = n_subplots
        self.wrap_dim = wrap_dim

    def _determine_axis_sharing(self, pair_spec: PairSpec) -> None:
        """Update subplot spec with default or specified axis sharing parameters."""
        axis_to_dim = {"x": "col", "y": "row"}
        key: str
        val: str | bool
        for axis in "xy":
            key = f"share{axis}"
            # Always use user-specified value, if present
            if key not in self.subplot_spec:
                if axis in pair_spec.get("structure", {}):
                    # Paired axes are shared along one dimension by default
                    if self.wrap is None and pair_spec.get("cross", True):
                        val = axis_to_dim[axis]
                    else:
                        val = False
                else:
                    # This will pick up faceted plots, as well as single subplot
                    # figures, where the value doesn't really matter
                    val = True
                self.subplot_spec[key] = val

    def init_figure(
        self,
        pair_spec: PairSpec,
        pyplot: bool = False,
        figure_kws: dict | None = None,
        target: Axes | Figure | SubFigure = None,
    ) -> Figure:
        """Initialize matplotlib objects and add seaborn-relevant metadata."""
        # TODO reduce need to pass pair_spec here?

        if figure_kws is None:
            figure_kws = {}

        if isinstance(target, mpl.axes.Axes):

            if max(self.subplot_spec["nrows"], self.subplot_spec["ncols"]) > 1:
                err = " ".join([
                    "Cannot create multiple subplots after calling `Plot.on` with",
                    f"a {mpl.axes.Axes} object.",
                ])
                try:
                    err += f" You may want to use a {mpl.figure.SubFigure} instead."
                except AttributeError:  # SubFigure added in mpl 3.4
                    pass
                raise RuntimeError(err)

            self._subplot_list = [{
                "ax": target,
                "left": True,
                "right": True,
                "top": True,
                "bottom": True,
                "col": None,
                "row": None,
                "x": "x",
                "y": "y",
            }]
            self._figure = target.figure
            return self._figure

        elif (
            hasattr(mpl.figure, "SubFigure")  # Added in mpl 3.4
            and isinstance(target, mpl.figure.SubFigure)
        ):
            figure = target.figure
        elif isinstance(target, mpl.figure.Figure):
            figure = target
        else:
            if pyplot:
                figure = plt.figure(**figure_kws)
            else:
                figure = mpl.figure.Figure(**figure_kws)
            target = figure
        self._figure = figure

        axs = target.subplots(**self.subplot_spec, squeeze=False)

        if self.wrap:
            # Remove unused Axes and flatten the rest into a (2D) vector
            axs_flat = axs.ravel({"col": "C", "row": "F"}[self.wrap_dim])
            axs, extra = np.split(axs_flat, [self.n_subplots])
            for ax in extra:
                ax.remove()
            if self.wrap_dim == "col":
                axs = axs[np.newaxis, :]
            else:
                axs = axs[:, np.newaxis]

        # Get i, j coordinates for each Axes object
        # Note that i, j are with respect to faceting/pairing,
        # not the subplot grid itself, (which only matters in the case of wrapping).
        iter_axs: np.ndenumerate | zip
        if not pair_spec.get("cross", True):
            indices = np.arange(self.n_subplots)
            iter_axs = zip(zip(indices, indices), axs.flat)
        else:
            iter_axs = np.ndenumerate(axs)

        self._subplot_list = []
        for (i, j), ax in iter_axs:

            info = {"ax": ax}

            nrows, ncols = self.subplot_spec["nrows"], self.subplot_spec["ncols"]
            if not self.wrap:
                info["left"] = j % ncols == 0
                info["right"] = (j + 1) % ncols == 0
                info["top"] = i == 0
                info["bottom"] = i == nrows - 1
            elif self.wrap_dim == "col":
                info["left"] = j % ncols == 0
                info["right"] = ((j + 1) % ncols == 0) or ((j + 1) == self.n_subplots)
                info["top"] = j < ncols
                info["bottom"] = j >= (self.n_subplots - ncols)
            elif self.wrap_dim == "row":
                info["left"] = i < nrows
                info["right"] = i >= self.n_subplots - nrows
                info["top"] = i % nrows == 0
                info["bottom"] = ((i + 1) % nrows == 0) or ((i + 1) == self.n_subplots)

            if not pair_spec.get("cross", True):
                info["top"] = j < ncols
                info["bottom"] = j >= self.n_subplots - ncols

            for dim in ["row", "col"]:
                idx = {"row": i, "col": j}[dim]
                info[dim] = self.grid_dimensions[dim][idx]

            for axis in "xy":

                idx = {"x": j, "y": i}[axis]
                if axis in pair_spec.get("structure", {}):
                    key = f"{axis}{idx}"
                else:
                    key = axis
                info[axis] = key

            self._subplot_list.append(info)

        return figure

    def __iter__(self) -> Generator[dict, None, None]:  # TODO TypedDict?
        """Yield each subplot dictionary with Axes object and metadata."""
        yield from self._subplot_list

    def __len__(self) -> int:
        """Return the number of subplots in this figure."""
        return len(self._subplot_list)
