"""Simplified split-apply-combine paradigm on dataframes for internal use."""
from __future__ import annotations

from typing import cast, Iterable

import pandas as pd

from seaborn._core.rules import categorical_order

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from typing import Callable
    from pandas import DataFrame, MultiIndex, Index


class GroupBy:
    """
    Interface for Pandas GroupBy operations allowing specified group order.

    Writing our own class to do this has a few advantages:
    - It constrains the interface between Plot and Stat/Move objects
    - It allows control over the row order of the GroupBy result, which is
      important when using in the context of some Move operations (dodge, stack, ...)
    - It simplifies some complexities regarding the return type and Index contents
      one encounters with Pandas, especially for DataFrame -> DataFrame applies
    - It increases future flexibility regarding alternate DataFrame libraries

    """
    def __init__(self, order: list[str] | dict[str, list | None]):
        """
        Initialize the GroupBy from grouping variables and optional level orders.

        Parameters
        ----------
        order
            List of variable names or dict mapping names to desired level orders.
            Level order values can be None to use default ordering rules. The
            variables can include names that are not expected to appear in the
            data; these will be dropped before the groups are defined.

        Raises
        ------
        ValueError
            If `order` is empty.

        """
        if not order:
            raise ValueError("GroupBy requires at least one grouping variable")

        # Normalize the list form so the rest of the class only handles a dict;
        # None means "derive the level order from the data".
        if isinstance(order, list):
            order = {k: None for k in order}
        self.order = order

    def _get_groups(
        self, data: DataFrame
    ) -> tuple[str | list[str], Index | MultiIndex]:
        """
        Return index with Cartesian product of ordered grouping variable levels.

        Only the variables in `self.order` that are columns of `data` are used.
        A level order of None is resolved with `categorical_order(data[var])`.

        Returns
        -------
        grouper
            An empty list when no grouping variable is in `data`, a single
            column name when exactly one is, otherwise a list of column names.
        groups
            The matching (possibly empty) Index, or a MultiIndex holding the
            product of the level orders when there are several variables.

        """
        levels = {}
        for var, order in self.order.items():
            if var in data:
                if order is None:
                    order = categorical_order(data[var])
                levels[var] = order

        grouper: str | list[str]
        groups: Index | MultiIndex
        if not levels:
            grouper = []
            groups = pd.Index([])
        elif len(levels) > 1:
            grouper = list(levels)
            groups = pd.MultiIndex.from_product(levels.values(), names=grouper)
        else:
            # Single variable: group on the bare name (not a one-item list)
            # so that the group keys are scalars rather than tuples.
            grouper, = list(levels)
            groups = pd.Index(levels[grouper], name=grouper)
        return grouper, groups

    def _reorder_columns(self, res: DataFrame, data: DataFrame) -> DataFrame:
        """
        Reorder result columns to match original order with new columns appended.

        Columns of `data` that are absent from `res` are not added back.

        """
        cols = [c for c in data if c in res]
        cols += [c for c in res if c not in data]
        return res.reindex(columns=pd.Index(cols))

    def agg(self, data: DataFrame, *args, **kwargs) -> DataFrame:
        """
        Reduce each group to a single row in the output.

        The output will have a row for each unique combination of the grouping
        variable levels with null values for the aggregated variable(s) where
        those combinations do not appear in the dataset.

        Parameters
        ----------
        data
            Dataframe to aggregate.
        *args, **kwargs
            Passed through to the pandas groupby `agg` method.

        Raises
        ------
        ValueError
            If none of the grouping variables are columns of `data`.

        """
        grouper, groups = self._get_groups(data)

        if not grouper:
            # We will need to see whether there are valid usecases that end up here
            raise ValueError("No grouping variables are present in dataframe")

        res = (
            data
            .groupby(grouper, sort=False, observed=True)
            .agg(*args, **kwargs)
            # Reindexing imposes the requested level order and inserts
            # null rows for level combinations missing from the data.
            .reindex(groups)
            .reset_index()
            .pipe(self._reorder_columns, data)
        )

        return res

    def apply(
        self, data: DataFrame, func: Callable[..., DataFrame],
        *args, **kwargs,
    ) -> DataFrame:
        """
        Apply a DataFrame -> DataFrame mapping to each group.

        When none of the grouping variables are columns of `data`, `func` is
        applied once to the whole dataframe. Otherwise the per-group results
        are concatenated following the order of the grouping levels, and the
        grouping variable values are (re)assigned on each result. Levels that
        have no rows in `data` contribute nothing to the output.

        Parameters
        ----------
        data
            Dataframe to split into groups.
        func
            Callable taking a dataframe (plus `*args` / `**kwargs`) and
            returning a dataframe.
        *args, **kwargs
            Passed through to `func`.

        Raises
        ------
        ValueError
            Raised by `pandas.concat` when no group in `data` matches any of
            the grouping levels (there is nothing to concatenate).

        """
        grouper, groups = self._get_groups(data)

        if not grouper:
            return self._reorder_columns(func(data, *args, **kwargs), data)

        parts = {}
        # Pass observed=True explicitly, as `agg` does: with categorical grouping
        # columns this keeps `func` from being called on empty frames for
        # unobserved categories and avoids relying on the pandas default.
        for key, part_df in data.groupby(grouper, sort=False, observed=True):
            parts[key] = func(part_df, *args, **kwargs)

        stack = []
        # Iterate over the ordered levels (not the groupby order) so that the
        # output rows follow the requested group order.
        for key in groups:
            if key in parts:
                if isinstance(grouper, list):
                    # Implies that we had a MultiIndex so key is iterable
                    group_ids = dict(zip(grouper, cast(Iterable, key)))
                else:
                    group_ids = {grouper: key}
                stack.append(parts[key].assign(**group_ids))

        res = pd.concat(stack, ignore_index=True)
        return self._reorder_columns(res, data)
