from __future__ import annotations
from itertools import product
from inspect import signature
import warnings
from textwrap import dedent

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

from ._oldcore import VectorPlotter, variable_type, categorical_order
from ._compat import share_axis
from . import utils
from .utils import (
    adjust_legend_subtitles, _check_argument, _draw_figure, _disable_autolayout
)
from .palettes import color_palette, blend_palette
from ._docstrings import (
    DocstringComponents,
    _core_docs,
)

# Public names exported by ``from <module> import *``.
__all__ = [
    "FacetGrid",
    "PairGrid",
    "JointGrid",
    "pairplot",
    "jointplot",
]


# Shared parameter documentation, built from the core docstring components.
_param_docs = DocstringComponents.from_nested_components(core=_core_docs["params"])


class _BaseGrid:
    """Base class for grids of subplots."""

    def set(self, **kwargs):
        """Set attributes on each subplot Axes."""
        for ax in self.axes.flat:
            if ax is not None:  # Handle removed axes
                ax.set(**kwargs)
        return self

    @property
    def fig(self):
        """DEPRECATED: prefer the `figure` property."""
        # Grid.figure is preferred because it matches the Axes attribute name.
        # But as the maintanace burden on having this property is minimal,
        # let's be slow about formally deprecating it. For now just note its deprecation
        # in the docstring; add a warning in version 0.13, and eventually remove it.
        return self._figure

    @property
    def figure(self):
        """Access the :class:`matplotlib.figure.Figure` object underlying the grid."""
        return self._figure

    def apply(self, func, *args, **kwargs):
        """
        Pass the grid to a user-supplied function and return self.

        The `func` must accept an object of this type for its first
        positional argument. Additional arguments are passed through.
        The return value of `func` is ignored; this method returns self.
        See the `pipe` method if you want the return value.

        Added in v0.12.0.

        """
        func(self, *args, **kwargs)
        return self

    def pipe(self, func, *args, **kwargs):
        """
        Pass the grid to a user-supplied function and return its value.

        The `func` must accept an object of this type for its first
        positional argument. Additional arguments are passed through.
        The return value of `func` becomes the return value of this method.
        See the `apply` method if you want to return self instead.

        Added in v0.12.0.

        """
        return func(self, *args, **kwargs)

    def savefig(self, *args, **kwargs):
        """
        Save an image of the plot.

        This wraps :meth:`matplotlib.figure.Figure.savefig`, using bbox_inches="tight"
        by default. Parameters are passed through to the matplotlib function.

        """
        kwargs = kwargs.copy()
        kwargs.setdefault("bbox_inches", "tight")
        self.figure.savefig(*args, **kwargs)


class Grid(_BaseGrid):
    """A grid that can have multiple subplots and an external legend."""
    _margin_titles = False
    _legend_out = True

    def __init__(self):

        # Rectangle handed to ``Figure.tight_layout``; ``add_legend`` shrinks its
        # right edge when the legend is drawn outside the axes.
        self._tight_layout_rect = [0, 0, 1, 1]
        self._tight_layout_pad = None

        # This attribute is set externally and is a hack to handle newer functions that
        # don't add proxy artists onto the Axes. We need an overall cleaner approach.
        self._extract_legend_handles = False

    def tight_layout(self, *args, **kwargs):
        """Call fig.tight_layout within a rect that excludes the legend.

        ``rect`` defaults to ``self._tight_layout_rect`` and ``pad`` to
        ``self._tight_layout_pad`` (when set); explicit arguments win.
        Returns self for chaining.
        """
        kwargs = kwargs.copy()
        kwargs.setdefault("rect", self._tight_layout_rect)
        if self._tight_layout_pad is not None:
            kwargs.setdefault("pad", self._tight_layout_pad)
        self._figure.tight_layout(*args, **kwargs)
        return self

    def add_legend(self, legend_data=None, title=None, label_order=None,
                   adjust_subtitles=False, **kwargs):
        """Draw a legend, maybe placing it outside axes and resizing the figure.

        Parameters
        ----------
        legend_data : dict
            Dictionary mapping label names (or two-element tuples where the
            second element is a label name) to matplotlib artist handles. The
            default reads from ``self._legend_data``.
        title : string
            Title for the legend. The default reads from ``self._hue_var``.
        label_order : list of labels
            The order that the legend entries should appear in. The default
            reads from ``self.hue_names``.
        adjust_subtitles : bool
            If True, modify entries with invisible artists to left-align
            the labels and set the font size to that of a title.
        kwargs : key, value pairings
            Other keyword arguments are passed to the underlying legend methods
            on the Figure or Axes object.

        Returns
        -------
        self : Grid instance
            Returns self for easy chaining.

        """
        # Find the data for the legend
        if legend_data is None:
            legend_data = self._legend_data
        if label_order is None:
            if self.hue_names is None:
                label_order = list(legend_data.keys())
            else:
                label_order = list(map(utils.to_utf8, self.hue_names))

        # Labels without a stored handle get an invisible placeholder patch
        blank_handle = mpl.patches.Patch(alpha=0, linewidth=0)
        handles = [legend_data.get(l, blank_handle) for l in label_order]
        title = self._hue_var if title is None else title
        title_size = mpl.rcParams["legend.title_fontsize"]

        # Unpack nested labels from a hierarchical legend
        labels = []
        for entry in label_order:
            if isinstance(entry, tuple):
                _, label = entry
            else:
                label = entry
            labels.append(label)

        # Set default legend kwargs
        kwargs.setdefault("scatterpoints", 1)

        if self._legend_out:

            kwargs.setdefault("frameon", False)
            kwargs.setdefault("loc", "center right")

            # Draw a full-figure legend outside the grid
            figlegend = self._figure.legend(handles, labels, **kwargs)

            self._legend = figlegend
            figlegend.set_title(title, prop={"size": title_size})

            if adjust_subtitles:
                adjust_legend_subtitles(figlegend)

            # Draw the plot to set the bounding boxes correctly
            _draw_figure(self._figure)

            # Calculate and set the new width of the figure so the legend fits
            legend_width = figlegend.get_window_extent().width / self._figure.dpi
            fig_width, fig_height = self._figure.get_size_inches()
            self._figure.set_size_inches(fig_width + legend_width, fig_height)

            # Draw the plot again to get the new transformations
            _draw_figure(self._figure)

            # Now calculate how much space we need on the right side
            legend_width = figlegend.get_window_extent().width / self._figure.dpi
            space_needed = legend_width / (fig_width + legend_width)
            margin = .04 if self._margin_titles else .01
            self._space_needed = margin + space_needed
            right = 1 - self._space_needed

            # Place the subplot axes to give space for the legend
            self._figure.subplots_adjust(right=right)
            self._tight_layout_rect[2] = right

        else:
            # Draw a legend in the first axis
            ax = self.axes.flat[0]
            kwargs.setdefault("loc", "best")

            leg = ax.legend(handles, labels, **kwargs)
            leg.set_title(title, prop={"size": title_size})
            self._legend = leg

            if adjust_subtitles:
                adjust_legend_subtitles(leg)

        return self

    def _update_legend_data(self, ax):
        """Extract the legend data from an axes object and save it.

        Label -> handle pairs found on ``ax`` are merged into
        ``self._legend_data``; afterwards the Axes' legend attribute is
        cleared (``ax.legend_ = None``).
        """
        data = {}

        # Get data directly from the legend, which is necessary
        # for newer functions that don't add labeled proxy artists
        legend = ax.legend_
        if legend is not None and self._extract_legend_handles:
            # matplotlib 3.7 renamed ``Legend.legendHandles`` to
            # ``Legend.legend_handles`` and later removed the old spelling,
            # so prefer the new name and fall back for older releases.
            handles = getattr(legend, "legend_handles", None)
            if handles is None:
                handles = legend.legendHandles
            labels = [t.get_text() for t in legend.texts]
            data.update({l: h for h, l in zip(handles, labels)})

        # Labeled artists on the Axes take precedence over legend-derived ones
        handles, labels = ax.get_legend_handles_labels()
        data.update({l: h for h, l in zip(handles, labels)})

        self._legend_data.update(data)

        # Now clear the legend
        ax.legend_ = None

    def _get_palette(self, data, hue, hue_order, palette):
        """Get a list of colors for the hue variable.

        With no ``hue`` a single-color palette is returned. Otherwise one
        color per hue level: the current color cycle (or "husl" when it is
        too short) by default, a lookup by level when ``palette`` is a dict,
        or whatever :func:`color_palette` makes of ``palette``.
        """
        if hue is None:
            palette = color_palette(n_colors=1)

        else:
            hue_names = categorical_order(data[hue], hue_order)
            n_colors = len(hue_names)

            # By default use either the current color palette or HUSL
            if palette is None:
                current_palette = utils.get_color_cycle()
                if n_colors > len(current_palette):
                    colors = color_palette("husl", n_colors)
                else:
                    colors = color_palette(n_colors=n_colors)

            # Allow for palette to map from hue variable names
            elif isinstance(palette, dict):
                color_names = [palette[h] for h in hue_names]
                colors = color_palette(color_names, n_colors)

            # Otherwise act as if we just got a list of colors
            else:
                colors = color_palette(palette, n_colors)

            palette = color_palette(colors, n_colors)

        return palette

    @property
    def legend(self):
        """The :class:`matplotlib.legend.Legend` object, if present."""
        try:
            return self._legend
        except AttributeError:
            return None

    def tick_params(self, axis='both', **kwargs):
        """Modify the ticks, tick labels, and gridlines.

        Parameters
        ----------
        axis : {'x', 'y', 'both'}
            The axis on which to apply the formatting.
        kwargs : keyword arguments
            Additional keyword arguments to pass to
            :meth:`matplotlib.axes.Axes.tick_params`.

        Returns
        -------
        self : Grid instance
            Returns self for easy chaining.

        """
        for ax in self.figure.axes:
            ax.tick_params(axis=axis, **kwargs)
        return self


_facet_docs = dict(

    data=dedent("""\
    data : DataFrame
        Tidy ("long-form") dataframe where each column is a variable and each
        row is an observation.\
    """),
    rowcol=dedent("""\
    row, col : vectors or keys in ``data``
        Variables that define subsets to plot on different facets.\
    """),
    rowcol_order=dedent("""\
    {row,col}_order : vector of strings
        Specify the order in which levels of the ``row`` and/or ``col`` variables
        appear in the grid of subplots.\
    """),
    col_wrap=dedent("""\
    col_wrap : int
        "Wrap" the column variable at this width, so that the column facets
        span multiple rows. Incompatible with a ``row`` facet.\
    """),
    share_xy=dedent("""\
    share{x,y} : bool, 'col', or 'row' optional
        If true, the facets will share y axes across columns and/or x axes
        across rows.\
    """),
    height=dedent("""\
    height : scalar
        Height (in inches) of each facet. See also: ``aspect``.\
    """),
    aspect=dedent("""\
    aspect : scalar
        Aspect ratio of each facet, so that ``aspect * height`` gives the width
        of each facet in inches.\
    """),
    palette=dedent("""\
    palette : palette name, list, or dict
        Colors to use for the different levels of the ``hue`` variable. Should
        be something that can be interpreted by :func:`color_palette`, or a
        dictionary mapping hue levels to matplotlib colors.\
    """),
    legend_out=dedent("""\
    legend_out : bool
        If ``True``, the figure size will be extended, and the legend will be
        drawn outside the plot on the center right.\
    """),
    margin_titles=dedent("""\
    margin_titles : bool
        If ``True``, the titles for the row variable are drawn to the right of
        the last column. This option is experimental and may not work in all
        cases.\
    """),
    facet_kws=dedent("""\
    facet_kws : dict
        Additional parameters passed to :class:`FacetGrid`.
    """),
)


class FacetGrid(Grid):
    """Multi-plot grid for plotting conditional relationships."""

    def __init__(
        self, data, *,
        row=None, col=None, hue=None, col_wrap=None,
        sharex=True, sharey=True, height=3, aspect=1, palette=None,
        row_order=None, col_order=None, hue_order=None, hue_kws=None,
        dropna=False, legend_out=True, despine=True,
        margin_titles=False, xlim=None, ylim=None, subplot_kws=None,
        gridspec_kws=None,
    ):
        # Build the figure and subplot grid for the faceting variables and
        # record the state used by the mapping methods. (The public docstring
        # is attached via ``__init__.__doc__`` in the class body.)

        super().__init__()

        # Determine the hue facet layer information
        hue_var = hue
        if hue is None:
            hue_names = None
        else:
            hue_names = categorical_order(data[hue], hue_order)

        colors = self._get_palette(data, hue, hue_order, palette)

        # Set up the lists of names for the row and column facet variables
        if row is None:
            row_names = []
        else:
            row_names = categorical_order(data[row], row_order)

        if col is None:
            col_names = []
        else:
            col_names = categorical_order(data[col], col_order)

        # Additional dict of kwarg -> list of values for mapping the hue var
        hue_kws = hue_kws if hue_kws is not None else {}

        # Make a boolean mask that is True anywhere there is an NA
        # value in one of the faceting variables, but only if dropna is True
        none_na = np.zeros(len(data), bool)
        if dropna:
            row_na = none_na if row is None else data[row].isnull()
            col_na = none_na if col is None else data[col].isnull()
            hue_na = none_na if hue is None else data[hue].isnull()
            not_na = ~(row_na | col_na | hue_na)
        else:
            not_na = ~none_na

        # Compute the grid shape
        ncol = 1 if col is None else len(col_names)
        nrow = 1 if row is None else len(row_names)
        self._n_facets = ncol * nrow

        if col_wrap is not None:
            if row is not None:
                err = "Cannot use `row` and `col_wrap` together."
                raise ValueError(err)
            if len(col_names) == 0:
                # Nothing to wrap: the wrapped grid would have zero rows and
                # zero axes. Fail before a figure is created (and leaked).
                err = "`col_wrap` requires a `col` variable with at least one level."
                raise ValueError(err)
            ncol = col_wrap
            nrow = int(np.ceil(len(col_names) / col_wrap))

        # Calculate the base figure size
        # This can get stretched later by a legend
        # TODO this doesn't account for axis labels
        figsize = (ncol * height * aspect, nrow * height)

        # Margin titles are not supported on a wrapped grid
        if col_wrap is not None:
            margin_titles = False

        # Build the subplot keyword dictionary (copies, so the caller's dicts
        # are not mutated by the additions below)
        subplot_kws = {} if subplot_kws is None else subplot_kws.copy()
        gridspec_kws = {} if gridspec_kws is None else gridspec_kws.copy()
        if xlim is not None:
            subplot_kws["xlim"] = xlim
        if ylim is not None:
            subplot_kws["ylim"] = ylim

        # --- Initialize the subplot grid

        with _disable_autolayout():
            fig = plt.figure(figsize=figsize)

        if col_wrap is None:

            kwargs = dict(squeeze=False,
                          sharex=sharex, sharey=sharey,
                          subplot_kw=subplot_kws,
                          gridspec_kw=gridspec_kws)

            axes = fig.subplots(nrow, ncol, **kwargs)

            if col is None and row is None:
                axes_dict = {}
            elif col is None:
                axes_dict = dict(zip(row_names, axes.flat))
            elif row is None:
                axes_dict = dict(zip(col_names, axes.flat))
            else:
                facet_product = product(row_names, col_names)
                axes_dict = dict(zip(facet_product, axes.flat))

        else:

            # If wrapping the col variable we need to make the grid ourselves
            if gridspec_kws:
                warnings.warn("`gridspec_kws` ignored when using `col_wrap`")

            n_axes = len(col_names)
            axes = np.empty(n_axes, object)
            axes[0] = fig.add_subplot(nrow, ncol, 1, **subplot_kws)
            # Any truthy share setting shares with the first Axes here
            if sharex:
                subplot_kws["sharex"] = axes[0]
            if sharey:
                subplot_kws["sharey"] = axes[0]
            for i in range(1, n_axes):
                axes[i] = fig.add_subplot(nrow, ncol, i + 1, **subplot_kws)

            axes_dict = dict(zip(col_names, axes))

        # --- Set up the class attributes

        # Attributes that are part of the public API but accessed through
        # a  property so that Sphinx adds them to the auto class doc
        self._figure = fig
        self._axes = axes
        self._axes_dict = axes_dict
        self._legend = None

        # Public attributes that aren't explicitly documented
        # (It's not obvious that having them be public was a good idea)
        self.data = data
        self.row_names = row_names
        self.col_names = col_names
        self.hue_names = hue_names
        self.hue_kws = hue_kws

        # Next the private variables
        self._nrow = nrow
        self._row_var = row
        self._ncol = ncol
        self._col_var = col

        self._margin_titles = margin_titles
        self._margin_titles_texts = []
        self._col_wrap = col_wrap
        self._hue_var = hue_var
        self._colors = colors
        self._legend_out = legend_out
        self._legend_data = {}
        self._x_var = None
        self._y_var = None
        self._sharex = sharex
        self._sharey = sharey
        self._dropna = dropna
        self._not_na = not_na

        # --- Make the axes look good

        self.set_titles()
        self.tight_layout()

        if despine:
            self.despine()

        # Hide redundant tick labels / axis labels on interior facets
        if sharex in [True, 'col']:
            for ax in self._not_bottom_axes:
                for label in ax.get_xticklabels():
                    label.set_visible(False)
                ax.xaxis.offsetText.set_visible(False)
                ax.xaxis.label.set_visible(False)

        if sharey in [True, 'row']:
            for ax in self._not_left_axes:
                for label in ax.get_yticklabels():
                    label.set_visible(False)
                ax.yaxis.offsetText.set_visible(False)
                ax.yaxis.label.set_visible(False)

    __init__.__doc__ = dedent("""\
        Initialize the matplotlib figure and FacetGrid object.

        This class maps a dataset onto multiple axes arrayed in a grid of rows
        and columns that correspond to *levels* of variables in the dataset.
        The plots it produces are often called "lattice", "trellis", or
        "small-multiple" graphics.

        It can also represent levels of a third variable with the ``hue``
        parameter, which plots different subsets of data in different colors.
        This uses color to resolve elements on a third dimension, but only
        draws subsets on top of each other and will not tailor the ``hue``
        parameter for the specific visualization the way that axes-level
        functions that accept ``hue`` will.

        The basic workflow is to initialize the :class:`FacetGrid` object with
        the dataset and the variables that are used to structure the grid. Then
        one or more plotting functions can be applied to each subset by calling
        :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe`. Finally, the
        plot can be tweaked with other methods to do things like change the
        axis labels, use different ticks, or add a legend. See the detailed
        code examples below for more information.

        .. warning::

            When using seaborn functions that infer semantic mappings from a
            dataset, care must be taken to synchronize those mappings across
            facets (e.g., by defining the ``hue`` mapping with a palette dict or
            setting the data type of the variables to ``category``). In most cases,
            it will be better to use a figure-level function (e.g. :func:`relplot`
            or :func:`catplot`) than to use :class:`FacetGrid` directly.

        See the :ref:`tutorial <grid_tutorial>` for more information.

        Parameters
        ----------
        {data}
        row, col, hue : strings
            Variables that define subsets of the data, which will be drawn on
            separate facets in the grid. See the ``{{var}}_order`` parameters to
            control the order of levels of this variable.
        {col_wrap}
        {share_xy}
        {height}
        {aspect}
        {palette}
        {{row,col,hue}}_order : lists
            Order for the levels of the faceting variables. By default, this
            will be the order that the levels appear in ``data`` or, if the
            variables are pandas categoricals, the category order.
        hue_kws : dictionary of param -> list of values mapping
            Other keyword arguments to insert into the plotting call to let
            other plot attributes vary across levels of the hue variable (e.g.
            the markers in a scatterplot).
        {legend_out}
        despine : boolean
            Remove the top and right spines from the plots.
        {margin_titles}
        {{x, y}}lim : tuples
            Limits for each of the axes on each facet (only relevant when
            share{{x, y}} is True).
        subplot_kws : dict
            Dictionary of keyword arguments passed to matplotlib subplot(s)
            methods.
        gridspec_kws : dict
            Dictionary of keyword arguments passed to
            :class:`matplotlib.gridspec.GridSpec`
            (via :meth:`matplotlib.figure.Figure.subplots`).
            Ignored if ``col_wrap`` is not ``None``.

        See Also
        --------
        PairGrid : Subplot grid for plotting pairwise relationships
        relplot : Combine a relational plot and a :class:`FacetGrid`
        displot : Combine a distribution plot and a :class:`FacetGrid`
        catplot : Combine a categorical plot and a :class:`FacetGrid`
        lmplot : Combine a regression plot and a :class:`FacetGrid`

        Examples
        --------

        .. note::

            These examples use seaborn functions to demonstrate some of the
            advanced features of the class, but in most cases you will want
            to use figure-level functions (e.g. :func:`displot`, :func:`relplot`)
            to make the plots shown here.

        .. include:: ../docstrings/FacetGrid.rst

        """).format(**_facet_docs)

    def facet_data(self):
        """Generator for name indices and data subsets for each facet.

        Yields
        ------
        (i, j, k), data_ijk : tuple of ints, DataFrame
            Indices into ``row_names``, ``col_names`` and ``hue_names`` and
            the rows of ``self.data`` belonging to that combination (also
            filtered by the NA mask computed at construction). Iteration order
            matches ``self.axes.flat``, or ``self.axes[i, j]`` when
            ``col_wrap`` is None.

        """
        data = self.data

        def level_masks(var, names):
            # One equality mask per level; an unused variable contributes a
            # single all-True mask so the product below still has one entry.
            if names:
                return [data[var] == name for name in names]
            return [np.repeat(True, len(data))]

        row_masks = level_masks(self._row_var, self.row_names)
        col_masks = level_masks(self._col_var, self.col_names)
        hue_masks = level_masks(self._hue_var, self.hue_names)

        combos = product(
            enumerate(row_masks), enumerate(col_masks), enumerate(hue_masks)
        )
        for (i, row_mask), (j, col_mask), (k, hue_mask) in combos:
            keep = row_mask & col_mask & hue_mask & self._not_na
            yield (i, j, k), data[keep]

    def map(self, func, *args, **kwargs):
        """Apply a plotting function to each facet's subset of the data.

        Parameters
        ----------
        func : callable
            A plotting function that takes data and keyword arguments. It
            must plot to the currently active matplotlib Axes and take a
            `color` keyword argument. If faceting on the `hue` dimension,
            it must also take a `label` keyword argument.
        args : strings
            Column names in self.data that identify variables with data to
            plot. The data for each variable is passed to `func` in the
            order the variables are specified in the call.
        kwargs : keyword arguments
            All keyword arguments are passed to the plotting function.

        Returns
        -------
        self : object
            Returns self.

        """
        # If color was a keyword argument, grab it here
        kw_color = kwargs.pop("color", None)

        # How we use the function depends on where it comes from
        func_module = str(getattr(func, "__module__", ""))

        # Check for categorical plots without order information
        if func_module == "seaborn.categorical":
            if "order" not in kwargs:
                warning = ("Using the {} function without specifying "
                           "`order` is likely to produce an incorrect "
                           "plot.".format(func.__name__))
                warnings.warn(warning)
            if len(args) == 3 and "hue_order" not in kwargs:
                warning = ("Using the {} function without specifying "
                           "`hue_order` is likely to produce an incorrect "
                           "plot.".format(func.__name__))
                warnings.warn(warning)

        # Only non-seaborn functions need pyplot's "current Axes" switched;
        # this depends only on func, so decide it once outside the loop.
        modify_state = not func_module.startswith("seaborn")

        # Iterate over the data subsets
        for (row_i, col_j, hue_k), data_ijk in self.facet_data():

            # If this subset is null, move on (``.empty`` avoids building
            # the ``.values`` array just to check its size)
            if data_ijk.empty:
                continue

            # Get the current axis
            ax = self.facet_axis(row_i, col_j, modify_state)

            # Decide what color to plot with
            kwargs["color"] = self._facet_color(hue_k, kw_color)

            # Insert the other hue aesthetics if appropriate
            for kw, val_list in self.hue_kws.items():
                kwargs[kw] = val_list[hue_k]

            # Insert a label in the keyword arguments for the legend
            if self._hue_var is not None:
                kwargs["label"] = utils.to_utf8(self.hue_names[hue_k])

            # Get the actual data we are going to plot with
            plot_data = data_ijk[list(args)]
            if self._dropna:
                plot_data = plot_data.dropna()
            plot_args = [column for _, column in plot_data.items()]

            # Some matplotlib functions don't handle pandas objects correctly
            if func_module.startswith("matplotlib"):
                plot_args = [v.values for v in plot_args]

            # Draw the plot
            self._facet_plot(func, ax, plot_args, kwargs)

        # Finalize the annotations and layout
        self._finalize_grid(args[:2])

        return self

    def map_dataframe(self, func, *args, **kwargs):
        """Like ``.map`` but passes args as strings and inserts data in kwargs.

        This method is suitable for plotting with functions that accept a
        long-form DataFrame as a `data` keyword argument and access the
        data in that DataFrame using string variable names.

        Parameters
        ----------
        func : callable
            A plotting function that takes data and keyword arguments. Unlike
            the `map` method, a function used here must "understand" Pandas
            objects. It also must plot to the currently active matplotlib Axes
            and take a `color` keyword argument. If faceting on the `hue`
            dimension, it must also take a `label` keyword argument.
        args : strings
            Column names in self.data that identify variables with data to
            plot. The data for each variable is passed to `func` in the
            order the variables are specified in the call.
        kwargs : keyword arguments
            All keyword arguments are passed to the plotting function.

        Returns
        -------
        self : object
            Returns self.

        """

        # If color was a keyword argument, grab it here
        kw_color = kwargs.pop("color", None)

        # Only non-seaborn functions need pyplot's "current Axes" switched.
        # Use getattr (as ``map`` does) so callables lacking ``__module__``
        # are tolerated; the decision is loop-invariant, so make it once.
        func_module = str(getattr(func, "__module__", ""))
        modify_state = not func_module.startswith("seaborn")

        # Iterate over the data subsets
        for (row_i, col_j, hue_k), data_ijk in self.facet_data():

            # If this subset is null, move on
            if data_ijk.empty:
                continue

            # Get the current axis
            ax = self.facet_axis(row_i, col_j, modify_state)

            # Decide what color to plot with
            kwargs["color"] = self._facet_color(hue_k, kw_color)

            # Insert the other hue aesthetics if appropriate
            for kw, val_list in self.hue_kws.items():
                kwargs[kw] = val_list[hue_k]

            # Insert a label in the keyword arguments for the legend
            if self._hue_var is not None:
                kwargs["label"] = self.hue_names[hue_k]

            # Stick the facet dataframe into the kwargs
            if self._dropna:
                data_ijk = data_ijk.dropna()
            kwargs["data"] = data_ijk

            # Draw the plot
            self._facet_plot(func, ax, args, kwargs)

        # For axis labels, prefer to use positional args for backcompat
        # but also extract the x/y kwargs and use if no corresponding arg
        axis_labels = [kwargs.get("x", None), kwargs.get("y", None)]
        for i, val in enumerate(args[:2]):
            axis_labels[i] = val
        self._finalize_grid(axis_labels)

        return self

    def _facet_color(self, hue_index, kw_color):
        """Return the color to plot hue level ``hue_index`` with.

        An explicit ``kw_color`` from the caller takes priority; otherwise the
        grid's palette entry for that level is used (``None`` stays ``None``).
        """
        palette_color = self._colors[hue_index]
        return palette_color if kw_color is None else kw_color

    def _facet_plot(self, func, ax, plot_args, plot_kwargs):
        """Draw one facet with ``func`` and collect its legend entries.

        For seaborn functions the positional ``plot_args`` are passed as the
        ``x``, ``y``, ``hue``, ``size``, ``style`` keywords (in that order)
        and the target Axes is passed explicitly as ``ax``; other functions
        are called with the arguments unchanged. Legend data is then harvested
        from ``ax`` via ``_update_legend_data``.
        """
        # getattr keeps this consistent with ``map``, which tolerates
        # callables that have no ``__module__`` attribute.
        func_module = str(getattr(func, "__module__", ""))

        # Draw the plot
        if func_module.startswith("seaborn"):
            plot_kwargs = plot_kwargs.copy()
            semantics = ["x", "y", "hue", "size", "style"]
            for key, val in zip(semantics, plot_args):
                plot_kwargs[key] = val
            plot_args = []
            plot_kwargs["ax"] = ax
        func(*plot_args, **plot_kwargs)

        # Sort out the supporting information
        self._update_legend_data(ax)

    def _finalize_grid(self, axlabels):
        """Finalize the annotations and layout.

        ``axlabels`` is unpacked as the positional arguments of
        ``set_axis_labels`` (x label, then y label); the layout is then
        re-tightened.
        """
        labels = tuple(axlabels)
        self.set_axis_labels(*labels)
        self.tight_layout()

    def facet_axis(self, row_i, col_j, modify_state=True):
        """Make the axis identified by these indices active and return it.

        With ``col_wrap`` set the axes array is one-dimensional, so only
        ``col_j`` is used; otherwise ``self.axes[row_i, col_j]`` is returned.
        When ``modify_state`` is true the Axes also becomes pyplot's current
        Axes (``plt.sca``).
        """
        wrapped = self._col_wrap is not None
        target = self.axes.flat[col_j] if wrapped else self.axes[row_i, col_j]

        if modify_state:
            plt.sca(target)
        return target

    def despine(self, **kwargs):
        """Remove axis spines from the facets.

        All keyword arguments are forwarded to ``utils.despine`` together with
        the grid's figure. Returns self for method chaining.
        """
        figure = self._figure
        utils.despine(figure, **kwargs)
        return self

    def set_axis_labels(self, x_var=None, y_var=None, clear_inner=True, **kwargs):
        """Set axis labels on the left column and bottom row of the grid.

        Each non-None label is remembered on the grid (``_x_var`` / ``_y_var``)
        and then drawn via :meth:`set_xlabels` / :meth:`set_ylabels`. The x
        label is handled before the y label. Returns self.
        """
        pending = (
            ("_x_var", "set_xlabels", x_var),
            ("_y_var", "set_ylabels", y_var),
        )
        for attr_name, setter_name, label in pending:
            if label is None:
                continue
            setattr(self, attr_name, label)
            getattr(self, setter_name)(label, clear_inner=clear_inner, **kwargs)

        return self

    def set_xlabels(self, label=None, clear_inner=True, **kwargs):
        """Label the x axis on the bottom row of the grid.

        ``label`` defaults to the stored ``_x_var``. Extra kwargs go to each
        bottom Axes' ``set_xlabel``. With ``clear_inner``, every other Axes
        gets an empty x label. Returns self.
        """
        text = self._x_var if label is None else label
        for bottom_ax in self._bottom_axes:
            bottom_ax.set_xlabel(text, **kwargs)
        if not clear_inner:
            return self
        for upper_ax in self._not_bottom_axes:
            upper_ax.set_xlabel("")
        return self

    def set_ylabels(self, label=None, clear_inner=True, **kwargs):
        """Label the y axis on the left column of the grid.

        ``label`` defaults to the stored ``_y_var``. Extra kwargs go to each
        left-column Axes' ``set_ylabel``. With ``clear_inner``, every other
        Axes gets an empty y label. Returns self.
        """
        text = self._y_var if label is None else label
        for left_ax in self._left_axes:
            left_ax.set_ylabel(text, **kwargs)
        if not clear_inner:
            return self
        for right_ax in self._not_left_axes:
            right_ax.set_ylabel("")
        return self

    def set_xticklabels(self, labels=None, step=None, **kwargs):
        """Set x axis tick labels of the grid.

        Every Axes first has its current x tick locations re-applied. Given
        ``labels``, those are used verbatim. Otherwise the existing label text
        is re-applied, keeping only every ``step``-th tick/label when ``step``
        is given. Extra kwargs go to ``set_xticklabels``. Returns self.
        """
        for facet_ax in self.axes.flat:
            facet_ax.set_xticks(facet_ax.get_xticks())
            if labels is not None:
                facet_ax.set_xticklabels(labels, **kwargs)
                continue

            tick_text = [tick.get_text() for tick in facet_ax.get_xticklabels()]
            if step is not None:
                facet_ax.set_xticks(facet_ax.get_xticks()[::step])
                tick_text = tick_text[::step]
            facet_ax.set_xticklabels(tick_text, **kwargs)
        return self

    def set_yticklabels(self, labels=None, step=None, **kwargs):
        """Set y axis tick labels on every facet of the grid.

        Parameters
        ----------
        labels : list of strings, optional
            Tick labels to apply to each Axes. If None, the current tick label
            text of each Axes is re-applied.
        step : int, optional
            When ``labels`` is None, keep only every ``step``-th tick and its
            label (mirrors :meth:`set_xticklabels`). Ignored when ``labels``
            is given.
        kwargs : key, value mappings
            Passed to :meth:`matplotlib.axes.Axes.set_yticklabels`.

        Returns
        -------
        self : object
            Returns self for method chaining.

        """
        for ax in self.axes.flat:
            # Re-apply the current tick locations explicitly before assigning
            # labels, so the labels are attached to a fixed set of ticks.
            curr_ticks = ax.get_yticks()
            ax.set_yticks(curr_ticks)
            if labels is None:
                curr_labels = [tick_label.get_text() for tick_label in ax.get_yticklabels()]
                if step is not None:
                    yticks = ax.get_yticks()[::step]
                    curr_labels = curr_labels[::step]
                    ax.set_yticks(yticks)
                ax.set_yticklabels(curr_labels, **kwargs)
            else:
                ax.set_yticklabels(labels, **kwargs)
        return self

    def set_titles(self, template=None, row_template=None, col_template=None,
                   **kwargs):
        """Draw titles either above each facet or on the grid margins.

        Parameters
        ----------
        template : string
            Template for all titles with the formatting keys {col_var} and
            {col_name} (if using a `col` faceting variable) and/or {row_var}
            and {row_name} (if using a `row` faceting variable). If None, it is
            built from ``row_template`` and/or ``col_template`` (joined with
            ``" | "`` when both faceting variables are assigned). Not used
            when the grid draws margin titles.
        row_template:
            Template for the row variable when titles are drawn on the grid
            margins. Must have {row_var} and {row_name} formatting keys.
            Defaults to ``"{row_var} = {row_name}"``.
        col_template:
            Template for the column variable when titles are drawn on the grid
            margins. Must have {col_var} and {col_name} formatting keys.
            Defaults to ``"{col_var} = {col_name}"``.
        kwargs : key, value mappings
            Passed to :meth:`matplotlib.axes.Axes.set_title` (and to
            :meth:`matplotlib.axes.Axes.annotate` for row margin titles).
            ``size`` defaults to ``rcParams["axes.labelsize"]``.

        Returns
        -------
        self: object
            Returns self.

        """
        args = dict(row_var=self._row_var, col_var=self._col_var)
        kwargs["size"] = kwargs.pop("size", mpl.rcParams["axes.labelsize"])

        # Establish default templates
        if row_template is None:
            row_template = "{row_var} = {row_name}"
        if col_template is None:
            col_template = "{col_var} = {col_name}"
        if template is None:
            if self._row_var is None:
                template = col_template
            elif self._col_var is None:
                template = row_template
            else:
                template = " | ".join([row_template, col_template])

        row_template = utils.to_utf8(row_template)
        col_template = utils.to_utf8(col_template)
        template = utils.to_utf8(template)

        if self._margin_titles:

            # Remove any existing title texts so repeated calls don't stack
            for text in self._margin_titles_texts:
                text.remove()
            self._margin_titles_texts = []

            if self.row_names is not None:
                # Draw the row titles on the right edge of the grid
                for i, row_name in enumerate(self.row_names):
                    ax = self.axes[i, -1]
                    args.update(dict(row_name=row_name))
                    title = row_template.format(**args)
                    text = ax.annotate(
                        title, xy=(1.02, .5), xycoords="axes fraction",
                        rotation=270, ha="left", va="center",
                        **kwargs
                    )
                    self._margin_titles_texts.append(text)

            if self.col_names is not None:
                # Draw the column titles as normal titles on the top row
                for j, col_name in enumerate(self.col_names):
                    args.update(dict(col_name=col_name))
                    title = col_template.format(**args)
                    self.axes[0, j].set_title(title, **kwargs)

            return self

        # Otherwise title each facet with all the necessary information
        if (self._row_var is not None) and (self._col_var is not None):
            for i, row_name in enumerate(self.row_names):
                for j, col_name in enumerate(self.col_names):
                    args.update(dict(row_name=row_name, col_name=col_name))
                    title = template.format(**args)
                    self.axes[i, j].set_title(title, **kwargs)
        elif self.row_names is not None and len(self.row_names):
            for i, row_name in enumerate(self.row_names):
                args.update(dict(row_name=row_name))
                title = template.format(**args)
                self.axes[i, 0].set_title(title, **kwargs)
        elif self.col_names is not None and len(self.col_names):
            for i, col_name in enumerate(self.col_names):
                args.update(dict(col_name=col_name))
                title = template.format(**args)
                # Index the flat array so col_wrap works
                self.axes.flat[i].set_title(title, **kwargs)
        return self

    def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws):
        """Add a reference line(s) to each facet.

        Parameters
        ----------
        x, y : numeric
            Value(s) to draw the line(s) at.
        color : :mod:`matplotlib color <matplotlib.colors>`
            Specifies the color of the reference line(s). Pass ``color=None`` to
            use ``hue`` mapping.
        linestyle : str
            Specifies the style of the reference line(s).
        line_kws : key, value mappings
            Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline`
            when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y``
            is not None.

        Returns
        -------
        :class:`FacetGrid` instance
            Returns ``self`` for easy method chaining.

        """
        # color/linestyle are merged in last so they sit alongside the extra kwargs
        style_kws = {**line_kws, "color": color, "linestyle": linestyle}

        # Vertical line first, then horizontal; each is mapped over every facet
        if x is not None:
            self.map(plt.axvline, x=x, **style_kws)
        if y is not None:
            self.map(plt.axhline, y=y, **style_kws)

        return self

    # ------ Properties that are part of the public API and documented by Sphinx

    @property
    def axes(self):
        """An array of the :class:`matplotlib.axes.Axes` objects in the grid."""
        grid_axes = self._axes
        return grid_axes

    @property
    def ax(self):
        """The :class:`matplotlib.axes.Axes` when no faceting variables are assigned.

        Raises AttributeError when the grid holds anything other than a
        single 1x1 Axes.
        """
        if self.axes.shape != (1, 1):
            raise AttributeError(
                "Use the `.axes` attribute when facet variables are assigned."
            )
        return self.axes[0, 0]

    @property
    def axes_dict(self):
        """A mapping of facet names to corresponding :class:`matplotlib.axes.Axes`.

        If only one of ``row`` or ``col`` is assigned, each key is a string
        representing a level of that variable. If both facet dimensions are
        assigned, each key is a ``({row_level}, {col_level})`` tuple.

        """
        mapping = self._axes_dict
        return mapping

    # ------ Private properties, that require some computation to get

    @property
    def _inner_axes(self):
        """Return a flat array of the inner axes."""
        if self._col_wrap is None:
            return self.axes[:-1, 1:].flat
        else:
            axes = []
            n_empty = self._nrow * self._ncol - self._n_facets
            for i, ax in enumerate(self.axes):
                append = (
                    i % self._ncol
                    and i < (self._ncol * (self._nrow - 1))
                    and i < (self._ncol * (self._nrow - 1) - n_empty)
                )
                if append:
                    axes.append(ax)
            return np.array(axes, object).flat

    @property
    def _left_axes(self):
        """Return a flat array of the left column of axes."""
        if self._col_wrap is None:
            return self.axes[:, 0].flat
        else:
            axes = []
            for i, ax in enumerate(self.axes):
                if not i % self._ncol:
                    axes.append(ax)
            return np.array(axes, object).flat

    @property
    def _not_left_axes(self):
        """Return a flat array of axes that aren't on the left column."""
        if self._col_wrap is None:
            return self.axes[:, 1:].flat
        else:
            axes = []
            for i, ax in enumerate(self.axes):
                if i % self._ncol:
                    axes.append(ax)
            return np.array(axes, object).flat

    @property
    def _bottom_axes(self):
        """Return a flat array of the bottom row of axes."""
        if self._col_wrap is None:
            return self.axes[-1, :].flat
        else:
            axes = []
            n_empty = self._nrow * self._ncol - self._n_facets
            for i, ax in enumerate(self.axes):
                append = (
                    i >= (self._ncol * (self._nrow - 1))
                    or i >= (self._ncol * (self._nrow - 1) - n_empty)
                )
                if append:
                    axes.append(ax)
            return np.array(axes, object).flat

    @property
    def _not_bottom_axes(self):
        """Return a flat array of axes that aren't on the bottom row."""
        if self._col_wrap is None:
            return self.axes[:-1, :].flat
        else:
            axes = []
            n_empty = self._nrow * self._ncol - self._n_facets
            for i, ax in enumerate(self.axes):
                append = (
                    i < (self._ncol * (self._nrow - 1))
                    and i < (self._ncol * (self._nrow - 1) - n_empty)
                )
                if append:
                    axes.append(ax)
            return np.array(axes, object).flat


class PairGrid(Grid):
    """Subplot grid for plotting pairwise relationships in a dataset.

    This object maps each variable in a dataset onto a column and row in a
    grid of multiple axes. Different axes-level plotting functions can be
    used to draw bivariate plots in the upper and lower triangles, and the
    marginal distribution of each variable can be shown on the diagonal.

    Several different common plots can be generated in a single line using
    :func:`pairplot`. Use :class:`PairGrid` when you need more flexibility.

    See the :ref:`tutorial <grid_tutorial>` for more information.

    """
    def __init__(
        self, data, *, hue=None, vars=None, x_vars=None, y_vars=None,
        hue_order=None, palette=None, hue_kws=None, corner=False, diag_sharey=True,
        height=2.5, aspect=1, layout_pad=.5, despine=True, dropna=False,
    ):
        """Initialize the plot figure and PairGrid object.

        Parameters
        ----------
        data : DataFrame
            Tidy (long-form) dataframe where each column is a variable and
            each row is an observation.
        hue : string (variable name)
            Variable in ``data`` to map plot aspects to different colors. This
            variable will be excluded from the default x and y variables.
        vars : list of variable names
            Variables within ``data`` to use, otherwise use every column with
            a numeric datatype. A single (scalar) variable name is also
            accepted.
        {x, y}_vars : lists of variable names
            Variables within ``data`` to use separately for the columns
            (``x_vars``) and rows (``y_vars``) of the figure; i.e. to make a
            non-square plot. A single (scalar) variable name is also accepted.
        hue_order : list of strings
            Order for the levels of the hue variable in the palette
        palette : dict or seaborn color palette
            Set of colors for mapping the ``hue`` variable. If a dict, keys
            should be values in the ``hue`` variable.
        hue_kws : dictionary of param -> list of values mapping
            Other keyword arguments to insert into the plotting call to let
            other plot attributes vary across levels of the hue variable (e.g.
            the markers in a scatterplot).
        corner : bool
            If True, don't add axes to the upper (off-diagonal) triangle of the
            grid, making this a "corner" plot.
        diag_sharey : bool
            If True, the extra axes that :meth:`map_diag` creates on the
            diagonal share their y axis with one another.
        height : scalar
            Height (in inches) of each facet.
        aspect : scalar
            Aspect * height gives the width (in inches) of each facet.
        layout_pad : scalar
            Padding between axes; used as ``pad`` for the grid's
            ``tight_layout``.
        despine : boolean
            Remove the top and right spines from the plots.
        dropna : boolean
            Drop missing values from the data before plotting.

        See Also
        --------
        pairplot : Easily drawing common uses of :class:`PairGrid`.
        FacetGrid : Subplot grid for plotting conditional relationships.

        Examples
        --------

        .. include:: ../docstrings/PairGrid.rst

        """

        super().__init__()

        # Sort out the variables that define the grid
        numeric_cols = self._find_numeric_cols(data)
        if hue in numeric_cols:
            numeric_cols.remove(hue)
        if vars is not None:
            # Treat a scalar like x_vars/y_vars do; list() on a string column
            # name would otherwise split it into single characters.
            if np.isscalar(vars):
                vars = [vars]
            x_vars = list(vars)
            y_vars = list(vars)
        if x_vars is None:
            x_vars = numeric_cols
        if y_vars is None:
            y_vars = numeric_cols

        if np.isscalar(x_vars):
            x_vars = [x_vars]
        if np.isscalar(y_vars):
            y_vars = [y_vars]

        self.x_vars = x_vars = list(x_vars)
        self.y_vars = y_vars = list(y_vars)
        self.square_grid = self.x_vars == self.y_vars

        if not x_vars:
            raise ValueError("No variables found for grid columns.")
        if not y_vars:
            raise ValueError("No variables found for grid rows.")

        # Create the figure and the array of subplots
        figsize = len(x_vars) * height * aspect, len(y_vars) * height

        with _disable_autolayout():
            fig = plt.figure(figsize=figsize)

        axes = fig.subplots(len(y_vars), len(x_vars),
                            sharex="col", sharey="row",
                            squeeze=False)

        # Possibly remove upper axes to make a corner grid
        # Note: setting up the axes is usually the most time-intensive part
        # of using the PairGrid. We are foregoing the speed improvement that
        # we would get by just not setting up the hidden axes so that we can
        # avoid implementing fig.subplots ourselves. But worth thinking about.
        self._corner = corner
        if corner:
            hide_indices = np.triu_indices_from(axes, 1)
            for i, j in zip(*hide_indices):
                axes[i, j].remove()
                axes[i, j] = None

        self._figure = fig
        self.axes = axes
        self.data = data

        # Save what we are going to do with the diagonal
        self.diag_sharey = diag_sharey
        self.diag_vars = None
        self.diag_axes = None

        self._dropna = dropna

        # Label the axes
        self._add_axis_labels()

        # Sort out the hue variable
        self._hue_var = hue
        if hue is None:
            self.hue_names = hue_order = ["_nolegend_"]
            self.hue_vals = pd.Series(["_nolegend_"] * len(data),
                                      index=data.index)
        else:
            # We need hue_order and hue_names because the former is used to control
            # the order of drawing and the latter is used to control the order of
            # the legend. hue_names can become string-typed while hue_order must
            # retain the type of the input data. This is messy but results from
            # the fact that PairGrid can implement the hue-mapping logic itself
            # (and was originally written exclusively that way) but now can delegate
            # to the axes-level functions, while always handling legend creation.
            # See GH2307
            hue_names = hue_order = categorical_order(data[hue], hue_order)
            if dropna:
                # Filter NA from the list of unique hue names
                hue_names = list(filter(pd.notnull, hue_names))
            self.hue_names = hue_names
            self.hue_vals = data[hue]

        # Additional dict of kwarg -> list of values for mapping the hue var
        self.hue_kws = hue_kws if hue_kws is not None else {}

        self._orig_palette = palette
        self._hue_order = hue_order
        self.palette = self._get_palette(data, hue, hue_order, palette)
        self._legend_data = {}

        # Make the plot look nice: hide x tick labels / labels above the
        # bottom row and y tick labels / labels right of the first column
        for ax in axes[:-1, :].flat:
            if ax is None:
                continue
            for label in ax.get_xticklabels():
                label.set_visible(False)
            ax.xaxis.offsetText.set_visible(False)
            ax.xaxis.label.set_visible(False)

        for ax in axes[:, 1:].flat:
            if ax is None:
                continue
            for label in ax.get_yticklabels():
                label.set_visible(False)
            ax.yaxis.offsetText.set_visible(False)
            ax.yaxis.label.set_visible(False)

        self._tight_layout_rect = [.01, .01, .99, .99]
        self._tight_layout_pad = layout_pad
        self._despine = despine
        if despine:
            utils.despine(fig=fig)
        self.tight_layout(pad=layout_pad)

    def map(self, func, **kwargs):
        """Plot with the same function in every subplot.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid
            Returns self for method chaining.

        """
        row_indices, col_indices = np.indices(self.axes.shape)
        indices = zip(row_indices.flat, col_indices.flat)
        self._map_bivariate(func, indices, **kwargs)

        return self

    def map_lower(self, func, **kwargs):
        """Plot with a bivariate function on the lower diagonal subplots.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid
            Returns self for method chaining.

        """
        indices = zip(*np.tril_indices_from(self.axes, -1))
        self._map_bivariate(func, indices, **kwargs)
        return self

    def map_upper(self, func, **kwargs):
        """Plot with a bivariate function on the upper diagonal subplots.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid
            Returns self for method chaining.

        """
        indices = zip(*np.triu_indices_from(self.axes, 1))
        self._map_bivariate(func, indices, **kwargs)
        return self

    def map_offdiag(self, func, **kwargs):
        """Plot with a bivariate function on the off-diagonal subplots.

        For a square grid this means the lower triangle (and the upper one
        unless the grid is a corner plot). For a non-square grid it means
        every cell whose x and y variables differ.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid
            Returns self for method chaining.

        """
        if self.square_grid:
            self.map_lower(func, **kwargs)
            if not self._corner:
                self.map_upper(func, **kwargs)
        else:
            indices = []
            for i, y_var in enumerate(self.y_vars):
                for j, x_var in enumerate(self.x_vars):
                    if x_var != y_var:
                        indices.append((i, j))
            self._map_bivariate(func, indices, **kwargs)
        return self

    def map_diag(self, func, **kwargs):
        """Plot with a univariate function on each diagonal subplot.

        Parameters
        ----------
        func : callable plotting function
            Must take an x array as a positional argument and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid
            Returns self for method chaining.

        """
        # Add special diagonal axes for the univariate plot (only on first call)
        if self.diag_axes is None:
            diag_vars = []
            diag_axes = []
            for i, y_var in enumerate(self.y_vars):
                for j, x_var in enumerate(self.x_vars):
                    if x_var == y_var:

                        # Make the density axes: a hidden twin of the grid
                        # Axes, so the univariate plot gets its own y axis
                        diag_vars.append(x_var)
                        ax = self.axes[i, j]
                        diag_ax = ax.twinx()
                        diag_ax.set_axis_off()
                        diag_axes.append(diag_ax)

                        # Work around matplotlib bug
                        # https://github.com/matplotlib/matplotlib/issues/15188
                        if not plt.rcParams.get("ytick.left", True):
                            for tick in ax.yaxis.majorTicks:
                                tick.tick1line.set_visible(False)

                        # Remove main y axis from density axes in a corner plot
                        if self._corner:
                            ax.yaxis.set_visible(False)
                            if self._despine:
                                utils.despine(ax=ax, left=True)
                            # TODO add optional density ticks (on the right)
                            # when drawing a corner plot?

            if self.diag_sharey and diag_axes:
                for ax in diag_axes[1:]:
                    share_axis(diag_axes[0], ax, "y")

            self.diag_vars = np.array(diag_vars, np.object_)
            self.diag_axes = np.array(diag_axes, np.object_)

        # Functions without a `hue` parameter get one call per hue level
        if "hue" not in signature(func).parameters:
            return self._map_diag_iter_hue(func, **kwargs)

        # Loop over diagonal variables and axes, making one plot in each
        for var, ax in zip(self.diag_vars, self.diag_axes):

            plot_kwargs = kwargs.copy()
            if str(func.__module__).startswith("seaborn"):
                plot_kwargs["ax"] = ax
            else:
                plt.sca(ax)

            vector = self.data[var]
            if self._hue_var is not None:
                hue = self.data[self._hue_var]
            else:
                hue = None

            if self._dropna:
                not_na = vector.notna()
                if hue is not None:
                    not_na &= hue.notna()
                vector = vector[not_na]
                if hue is not None:
                    hue = hue[not_na]

            plot_kwargs.setdefault("hue", hue)
            plot_kwargs.setdefault("hue_order", self._hue_order)
            plot_kwargs.setdefault("palette", self._orig_palette)
            func(x=vector, **plot_kwargs)
            # The grid builds its own legend, so drop any per-axes legend
            ax.legend_ = None

        self._add_axis_labels()
        return self

    def _map_diag_iter_hue(self, func, **kwargs):
        """Put marginal plot on each diagonal axes, iterating over hue.

        ``func`` is called once per level of ``self._hue_order`` with that
        level's subset of the variable, a ``label`` and a ``color`` (the
        palette color unless ``color`` was passed explicitly). Returns self.
        """
        # Plot on each of the diagonal axes
        fixed_color = kwargs.pop("color", None)

        for var, ax in zip(self.diag_vars, self.diag_axes):
            hue_grouped = self.data[var].groupby(self.hue_vals)

            plot_kwargs = kwargs.copy()
            if str(func.__module__).startswith("seaborn"):
                plot_kwargs["ax"] = ax
            else:
                plt.sca(ax)

            for k, label_k in enumerate(self._hue_order):

                # Attempt to get data for this level, allowing for empty
                try:
                    data_k = hue_grouped.get_group(label_k)
                except KeyError:
                    data_k = pd.Series([], dtype=float)

                if fixed_color is None:
                    color = self.palette[k]
                else:
                    color = fixed_color

                if self._dropna:
                    data_k = utils.remove_na(data_k)

                if str(func.__module__).startswith("seaborn"):
                    func(x=data_k, label=label_k, color=color, **plot_kwargs)
                else:
                    func(data_k, label=label_k, color=color, **plot_kwargs)

        self._add_axis_labels()

        return self

    def _map_bivariate(self, func, indices, **kwargs):
        """Draw a bivariate plot on the indicated axes.

        ``indices`` is an iterable of ``(row, col)`` pairs into ``self.axes``.
        Cells whose Axes was removed (corner mode) are skipped. When ``func``
        accepts ``hue``, ``hue_names`` is refreshed from the collected legend
        data afterwards.
        """
        # This is a hack to handle the fact that new distribution plots don't add
        # their artists onto the axes. This is probably superior in general, but
        # we'll need a better way to handle it in the axisgrid functions.
        from .distributions import histplot, kdeplot
        if func is histplot or func is kdeplot:
            self._extract_legend_handles = True

        kws = kwargs.copy()  # Use copy as we insert other kwargs
        for i, j in indices:
            x_var = self.x_vars[j]
            y_var = self.y_vars[i]
            ax = self.axes[i, j]
            if ax is None:  # i.e. we are in corner mode
                continue
            self._plot_bivariate(x_var, y_var, ax, func, **kws)
        self._add_axis_labels()

        if "hue" in signature(func).parameters:
            self.hue_names = list(self._legend_data)

    def _plot_bivariate(self, x_var, y_var, ax, func, **kwargs):
        """Draw a bivariate plot on the specified axes.

        Hue-aware functions (those with a ``hue`` parameter) get the full
        hue vector plus ``hue_order``/``palette`` unless ``hue`` was passed
        explicitly. Other functions are delegated to
        :meth:`_plot_bivariate_iter_hue`.
        """
        if "hue" not in signature(func).parameters:
            self._plot_bivariate_iter_hue(x_var, y_var, ax, func, **kwargs)
            return

        kwargs = kwargs.copy()
        if str(func.__module__).startswith("seaborn"):
            kwargs["ax"] = ax
        else:
            plt.sca(ax)

        # Avoid selecting the same column twice on the diagonal
        if x_var == y_var:
            axes_vars = [x_var]
        else:
            axes_vars = [x_var, y_var]

        if self._hue_var is not None and self._hue_var not in axes_vars:
            axes_vars.append(self._hue_var)

        data = self.data[axes_vars]
        if self._dropna:
            data = data.dropna()

        x = data[x_var]
        y = data[y_var]
        if self._hue_var is None:
            hue = None
        else:
            hue = data.get(self._hue_var)

        if "hue" not in kwargs:
            kwargs.update({
                "hue": hue, "hue_order": self._hue_order, "palette": self._orig_palette,
            })
        func(x=x, y=y, **kwargs)

        self._update_legend_data(ax)

    def _plot_bivariate_iter_hue(self, x_var, y_var, ax, func, **kwargs):
        """Draw a bivariate plot while iterating over hue subsets.

        ``func`` is called once per level of ``self._hue_order``. Each call
        gets that level's ``hue_kws`` values and a default ``color`` from the
        palette, plus a ``label`` when a hue variable is assigned.
        """
        kwargs = kwargs.copy()
        if str(func.__module__).startswith("seaborn"):
            kwargs["ax"] = ax
        else:
            plt.sca(ax)

        if x_var == y_var:
            axes_vars = [x_var]
        else:
            axes_vars = [x_var, y_var]

        hue_grouped = self.data.groupby(self.hue_vals)
        for k, label_k in enumerate(self._hue_order):

            kws = kwargs.copy()

            # Attempt to get data for this level, allowing for empty
            try:
                data_k = hue_grouped.get_group(label_k)
            except KeyError:
                data_k = pd.DataFrame(columns=axes_vars,
                                      dtype=float)

            if self._dropna:
                data_k = data_k[axes_vars].dropna()

            x = data_k[x_var]
            y = data_k[y_var]

            for kw, val_list in self.hue_kws.items():
                kws[kw] = val_list[k]
            kws.setdefault("color", self.palette[k])
            if self._hue_var is not None:
                kws["label"] = label_k

            if str(func.__module__).startswith("seaborn"):
                func(x=x, y=y, **kws)
            else:
                func(x, y, **kws)

        self._update_legend_data(ax)

    def _add_axis_labels(self):
        """Add labels to the left and bottom Axes."""
        for ax, label in zip(self.axes[-1, :], self.x_vars):
            ax.set_xlabel(label)
        for ax, label in zip(self.axes[:, 0], self.y_vars):
            ax.set_ylabel(label)

    def _find_numeric_cols(self, data):
        """Find which variables in a DataFrame are numeric.

        Returns the column names, in ``data`` order, whose
        ``variable_type`` is ``"numeric"``.
        """
        return [col for col in data if variable_type(data[col]) == "numeric"]


class JointGrid(_BaseGrid):
    """Grid for drawing a bivariate plot with marginal univariate plots.

    Many plots can be drawn by using the figure-level interface :func:`jointplot`.
    Use this class directly when you need more flexibility.

    """

    def __init__(
        self, data=None, *,
        x=None, y=None, hue=None,
        height=6, ratio=5, space=.2,
        palette=None, hue_order=None, hue_norm=None,
        dropna=False, xlim=None, ylim=None, marginal_ticks=False,
    ):

        # Set up the subplot grid: a square (ratio + 1) x (ratio + 1) GridSpec.
        # The joint Axes spans everything except the top row and right column;
        # the marginal Axes occupy those and share an axis with the joint Axes.
        f = plt.figure(figsize=(height, height))
        gs = plt.GridSpec(ratio + 1, ratio + 1)

        ax_joint = f.add_subplot(gs[1:, :-1])
        ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)
        ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)

        self._figure = f
        self.ax_joint = ax_joint
        self.ax_marg_x = ax_marg_x
        self.ax_marg_y = ax_marg_y

        # Turn off tick visibility for the measure axis on the marginal plots
        # (it is shared with the joint Axes, which already shows these labels)
        plt.setp(ax_marg_x.get_xticklabels(), visible=False)
        plt.setp(ax_marg_y.get_yticklabels(), visible=False)
        plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False)
        plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False)

        # Turn off the ticks on the density axis for the marginal plots
        if not marginal_ticks:
            plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)
            plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)
            plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)
            plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)
            plt.setp(ax_marg_x.get_yticklabels(), visible=False)
            plt.setp(ax_marg_y.get_xticklabels(), visible=False)
            plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False)
            plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False)
            ax_marg_x.yaxis.grid(False)
            ax_marg_y.xaxis.grid(False)

        # Process the input variables, keeping only columns with some data
        p = VectorPlotter(data=data, variables=dict(x=x, y=y, hue=hue))
        plot_data = p.plot_data.loc[:, p.plot_data.notna().any()]

        # Possibly drop NA (rows missing any of the assigned variables)
        if dropna:
            plot_data = plot_data.dropna()

        def get_var(var):
            # Return the vector for ``var`` renamed to its original variable
            # name, or None when the variable was not assigned.
            vector = plot_data.get(var, None)
            if vector is not None:
                vector = vector.rename(p.variables.get(var, None))
            return vector

        self.x = get_var("x")
        self.y = get_var("y")
        self.hue = get_var("hue")

        for axis in "xy":
            name = p.variables.get(axis, None)
            if name is not None:
                getattr(ax_joint, f"set_{axis}label")(name)

        if xlim is not None:
            ax_joint.set_xlim(xlim)
        if ylim is not None:
            ax_joint.set_ylim(ylim)

        # Store the semantic mapping parameters for axes-level functions
        self._hue_params = dict(palette=palette, hue_order=hue_order, hue_norm=hue_norm)

        # Make the grid look nice
        utils.despine(f)
        if not marginal_ticks:
            utils.despine(ax=ax_marg_x, left=True)
            utils.despine(ax=ax_marg_y, bottom=True)
        for axes in [ax_marg_x, ax_marg_y]:
            for axis in [axes.xaxis, axes.yaxis]:
                axis.label.set_visible(False)
        f.tight_layout()
        f.subplots_adjust(hspace=space, wspace=space)

    def _inject_kwargs(self, func, kws, params):
        """Add params to kws if they are accepted by func.

        Existing entries in ``kws`` take precedence (``setdefault``).
        """
        func_params = signature(func).parameters
        for key, val in params.items():
            if key in func_params:
                kws.setdefault(key, val)

    def plot(self, joint_func, marginal_func, **kwargs):
        """Draw the plot by passing functions for joint and marginal axes.

        This method passes the ``kwargs`` dictionary to both functions. If you
        need more control, call :meth:`JointGrid.plot_joint` and
        :meth:`JointGrid.plot_marginals` directly with specific parameters.

        Parameters
        ----------
        joint_func, marginal_func : callables
            Functions to draw the bivariate and univariate plots. See methods
            referenced above for information about the required characteristics
            of these functions.
        kwargs
            Additional keyword arguments are passed to both functions.

        Returns
        -------
        :class:`JointGrid` instance
            Returns ``self`` for easy method chaining.

        """
        self.plot_marginals(marginal_func, **kwargs)
        self.plot_joint(joint_func, **kwargs)
        return self

    def plot_joint(self, func, **kwargs):
        """Draw a bivariate plot on the joint axes of the grid.

        Parameters
        ----------
        func : plotting callable
            If a seaborn function, it should accept ``x`` and ``y``. Otherwise,
            it must accept ``x`` and ``y`` vectors of data as the first two
            positional arguments, and it must plot on the "current" axes.
            If ``hue`` was defined in the class constructor, the function must
            accept ``hue`` as a parameter.
        kwargs
            Keyword arguments are passed to the plotting function.

        Returns
        -------
        :class:`JointGrid` instance
            Returns ``self`` for easy method chaining.

        """
        kwargs = kwargs.copy()
        seaborn_func = str(func.__module__).startswith("seaborn")
        if seaborn_func:
            kwargs["ax"] = self.ax_joint
        else:
            plt.sca(self.ax_joint)
        if self.hue is not None:
            kwargs["hue"] = self.hue
            self._inject_kwargs(func, kwargs, self._hue_params)

        if seaborn_func:
            func(x=self.x, y=self.y, **kwargs)
        else:
            func(self.x, self.y, **kwargs)

        return self

    def plot_marginals(self, func, **kwargs):
        """Draw univariate plots on each marginal axes.

        Parameters
        ----------
        func : plotting callable
            If a seaborn function, it should accept ``x`` and ``y`` and plot
            when only one of them is defined. Otherwise, it must accept a vector
            of data as the first positional argument and determine its orientation
            using the ``orientation`` or ``vertical`` parameter, and it must plot
            on the "current" axes. If ``hue`` was defined in the class constructor,
            it must accept ``hue`` as a parameter.
        kwargs
            Keyword arguments are passed to the plotting function.

        Returns
        -------
        :class:`JointGrid` instance
            Returns ``self`` for easy method chaining.

        Raises
        ------
        TypeError
            If ``func`` is not a seaborn function and accepts neither an
            ``orientation`` nor a ``vertical`` parameter.

        """
        seaborn_func = (
            str(func.__module__).startswith("seaborn")
            # deprecated distplot has a legacy API, special case it
            and not func.__name__ == "distplot"
        )
        func_params = signature(func).parameters
        kwargs = kwargs.copy()
        if self.hue is not None:
            kwargs["hue"] = self.hue
            self._inject_kwargs(func, kwargs, self._hue_params)

        if "legend" in func_params:
            kwargs.setdefault("legend", False)

        if "orientation" in func_params:
            # e.g. plt.hist
            orient_kw_x = {"orientation": "vertical"}
            orient_kw_y = {"orientation": "horizontal"}
        elif "vertical" in func_params:
            # e.g. sns.distplot (also how did this get backwards?)
            orient_kw_x = {"vertical": False}
            orient_kw_y = {"vertical": True}
        elif not seaborn_func:
            # Without an orientation parameter there is no way to draw the
            # y marginal sideways, so fail clearly before drawing anything.
            name = getattr(func, "__name__", repr(func))
            raise TypeError(
                f"`{name}` cannot be used with `plot_marginals`: a non-seaborn "
                "function must accept an `orientation` or `vertical` parameter."
            )

        if seaborn_func:
            func(x=self.x, ax=self.ax_marg_x, **kwargs)
        else:
            plt.sca(self.ax_marg_x)
            func(self.x, **orient_kw_x, **kwargs)

        if seaborn_func:
            func(y=self.y, ax=self.ax_marg_y, **kwargs)
        else:
            plt.sca(self.ax_marg_y)
            func(self.y, **orient_kw_y, **kwargs)

        self.ax_marg_x.yaxis.get_label().set_visible(False)
        self.ax_marg_y.xaxis.get_label().set_visible(False)

        return self

    def refline(
        self, *, x=None, y=None, joint=True, marginal=True,
        color='.5', linestyle='--', **line_kws
    ):
        """Add reference line(s) to joint and/or marginal axes.

        Parameters
        ----------
        x, y : numeric
            Value(s) to draw the line(s) at.
        joint, marginal : bools
            Whether to add the reference line(s) to the joint/marginal axes.
        color : :mod:`matplotlib color <matplotlib.colors>`
            Specifies the color of the reference line(s).
        linestyle : str
            Specifies the style of the reference line(s).
        line_kws : key, value mappings
            Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline`
            when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y``
            is not None.

        Returns
        -------
        :class:`JointGrid` instance
            Returns ``self`` for easy method chaining.

        """
        # ``line_kws`` is a fresh dict built from **kwargs, so mutating it is safe
        line_kws['color'] = color
        line_kws['linestyle'] = linestyle

        if x is not None:
            if joint:
                self.ax_joint.axvline(x, **line_kws)
            if marginal:
                self.ax_marg_x.axvline(x, **line_kws)

        if y is not None:
            if joint:
                self.ax_joint.axhline(y, **line_kws)
            if marginal:
                self.ax_marg_y.axhline(y, **line_kws)

        return self

    def set_axis_labels(self, xlabel="", ylabel="", **kwargs):
        """Set axis labels on the bivariate axes.

        Parameters
        ----------
        xlabel, ylabel : strings
            Label names for the x and y variables.
        kwargs : key, value mappings
            Other keyword arguments are passed to the following functions:

            - :meth:`matplotlib.axes.Axes.set_xlabel`
            - :meth:`matplotlib.axes.Axes.set_ylabel`

        Returns
        -------
        :class:`JointGrid` instance
            Returns ``self`` for easy method chaining.

        """
        self.ax_joint.set_xlabel(xlabel, **kwargs)
        self.ax_joint.set_ylabel(ylabel, **kwargs)
        return self


# The constructor docstring is attached after the class is created so that the
# shared parameter / see-also snippets can be interpolated with str.format
# (literal braces are therefore escaped as ``{{...}}``).
JointGrid.__init__.__doc__ = """\
Set up the grid of subplots and store data internally for easy plotting.

Parameters
----------
{params.core.data}
{params.core.xy}
height : number
    Size of each side of the figure in inches (it will be square).
ratio : number
    Ratio of joint axes height to marginal axes height.
space : number
    Space between the joint and marginal axes.
dropna : bool
    If True, remove missing observations before plotting.
{{x, y}}lim : pairs of numbers
    Set axis limits to these values before plotting.
marginal_ticks : bool
    If False, suppress ticks on the count/density axis of the marginal plots.
{params.core.hue}
    Note: unlike in :class:`FacetGrid` or :class:`PairGrid`, the axes-level
    functions must support ``hue`` to use it in :class:`JointGrid`.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}

See Also
--------
{seealso.jointplot}
{seealso.pairgrid}
{seealso.pairplot}

Examples
--------

.. include:: ../docstrings/JointGrid.rst

""".format(
    params=_param_docs,
    returns=_core_docs["returns"],
    seealso=_core_docs["seealso"],
)


def pairplot(
    data, *,
    hue=None, hue_order=None, palette=None,
    vars=None, x_vars=None, y_vars=None,
    kind="scatter", diag_kind="auto", markers=None,
    height=2.5, aspect=1, corner=False, dropna=False,
    plot_kws=None, diag_kws=None, grid_kws=None, size=None,
):
    """Plot pairwise relationships in a dataset.

    By default, this function will create a grid of Axes such that each numeric
    variable in ``data`` will be shared across the y-axes across a single row and
    the x-axes across a single column. The diagonal plots are treated
    differently: a univariate distribution plot is drawn to show the marginal
    distribution of the data in each column.

    It is also possible to show a subset of variables or plot different
    variables on the rows and columns.

    This is a high-level interface for :class:`PairGrid` that is intended to
    make it easy to draw a few common styles. You should use :class:`PairGrid`
    directly if you need more flexibility.

    Parameters
    ----------
    data : `pandas.DataFrame`
        Tidy (long-form) dataframe where each column is a variable and
        each row is an observation.
    hue : name of variable in ``data``
        Variable in ``data`` to map plot aspects to different colors.
    hue_order : list of strings
        Order for the levels of the hue variable in the palette.
    palette : dict or seaborn color palette
        Set of colors for mapping the ``hue`` variable. If a dict, keys
        should be values in the ``hue`` variable.
    vars : list of variable names
        Variables within ``data`` to use, otherwise use every column with
        a numeric datatype.
    {x, y}_vars : lists of variable names
        Variables within ``data`` to use separately for the rows and
        columns of the figure; i.e. to make a non-square plot.
    kind : {'scatter', 'kde', 'hist', 'reg'}
        Kind of plot to make.
    diag_kind : {'auto', 'hist', 'kde', None}
        Kind of plot for the diagonal subplots. If 'auto', choose based on
        whether or not ``hue`` is used.
    markers : single matplotlib marker code or list
        Either the marker to use for all scatterplot points or a list of markers
        with a length the same as the number of levels in the hue variable so that
        differently colored points will also have different scatterplot
        markers.
    height : scalar
        Height (in inches) of each facet.
    aspect : scalar
        Aspect * height gives the width (in inches) of each facet.
    corner : bool
        If True, don't add axes to the upper (off-diagonal) triangle of the
        grid, making this a "corner" plot.
    dropna : boolean
        Drop missing values from the data before plotting.
    {plot, diag, grid}_kws : dicts
        Dictionaries of keyword arguments. ``plot_kws`` are passed to the
        bivariate plotting function, ``diag_kws`` are passed to the univariate
        plotting function, and ``grid_kws`` are passed to the :class:`PairGrid`
        constructor.

    Returns
    -------
    grid : :class:`PairGrid`
        Returns the underlying :class:`PairGrid` instance for further tweaking.

    Raises
    ------
    TypeError
        If ``data`` is not a :class:`pandas.DataFrame`.
    ValueError
        If ``kind="reg"`` and ``markers`` is a list whose length does not match
        the number of hue levels.

    See Also
    --------
    PairGrid : Subplot grid for more flexible plotting of pairwise relationships.
    JointGrid : Grid for plotting joint and marginal distributions of two variables.

    Examples
    --------

    .. include:: ../docstrings/pairplot.rst

    """
    # Avoid circular import
    from .distributions import histplot, kdeplot

    # Handle deprecations
    if size is not None:
        height = size
        msg = ("The `size` parameter has been renamed to `height`; "
               "please update your code.")
        # stacklevel=2 points the warning at the caller's line
        warnings.warn(msg, UserWarning, stacklevel=2)

    if not isinstance(data, pd.DataFrame):
        raise TypeError(
            f"'data' must be pandas DataFrame object, not: {type(data)}")

    # Work on copies so the caller's dictionaries are never mutated
    plot_kws = {} if plot_kws is None else plot_kws.copy()
    diag_kws = {} if diag_kws is None else diag_kws.copy()
    grid_kws = {} if grid_kws is None else grid_kws.copy()

    # Resolve "auto" diag kind
    if diag_kind == "auto":
        if hue is None:
            diag_kind = "kde" if kind == "kde" else "hist"
        else:
            diag_kind = "hist" if kind == "hist" else "kde"

    # Set up the PairGrid
    grid_kws.setdefault("diag_sharey", diag_kind == "hist")
    grid = PairGrid(data, vars=vars, x_vars=x_vars, y_vars=y_vars, hue=hue,
                    hue_order=hue_order, palette=palette, corner=corner,
                    height=height, aspect=aspect, dropna=dropna, **grid_kws)

    # Add the markers here as PairGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if markers is not None:
        if kind == "reg":
            # Needed until regplot supports style
            if grid.hue_names is None:
                n_markers = 1
            else:
                n_markers = len(grid.hue_names)
            if not isinstance(markers, list):
                markers = [markers] * n_markers
            if len(markers) != n_markers:
                raise ValueError("markers must be a singleton or a list of "
                                 "markers for each level of the hue variable")
            grid.hue_kws = {"marker": markers}
        elif kind == "scatter":
            if isinstance(markers, str):
                plot_kws["marker"] = markers
            elif hue is not None:
                plot_kws["style"] = data[hue]
                plot_kws["markers"] = markers

    # Draw the marginal plots on the diagonal
    diag_kws.setdefault("legend", False)
    if diag_kind == "hist":
        grid.map_diag(histplot, **diag_kws)
    elif diag_kind == "kde":
        diag_kws.setdefault("fill", True)
        diag_kws.setdefault("warn_singular", False)
        grid.map_diag(kdeplot, **diag_kws)

    # Maybe plot on the off-diagonals
    if diag_kind is not None:
        plotter = grid.map_offdiag
    else:
        plotter = grid.map

    if kind == "scatter":
        from .relational import scatterplot  # Avoid circular import
        plotter(scatterplot, **plot_kws)
    elif kind == "reg":
        from .regression import regplot  # Avoid circular import
        plotter(regplot, **plot_kws)
    elif kind == "kde":
        plot_kws.setdefault("warn_singular", False)
        plotter(kdeplot, **plot_kws)
    elif kind == "hist":
        plotter(histplot, **plot_kws)

    # Add a legend
    if hue is not None:
        grid.add_legend()

    grid.tight_layout()

    return grid


def jointplot(
    data=None, *, x=None, y=None, hue=None, kind="scatter",
    height=6, ratio=5, space=.2, dropna=False, xlim=None, ylim=None,
    color=None, palette=None, hue_order=None, hue_norm=None, marginal_ticks=False,
    joint_kws=None, marginal_kws=None,
    **kwargs
):
    # The public docstring is attached after this definition via
    # ``jointplot.__doc__ = ...`` so shared parameter docs can be interpolated.

    # Avoid circular imports
    from .relational import scatterplot
    from .regression import regplot, residplot
    from .distributions import histplot, kdeplot, _freedman_diaconis_bins

    if kwargs.pop("ax", None) is not None:
        msg = "Ignoring `ax`; jointplot is a figure-level function."
        warnings.warn(msg, UserWarning, stacklevel=2)

    # Set up empty default kwarg dicts (copies, so caller dicts are untouched);
    # loose **kwargs go to the joint plot and supersede ``joint_kws``
    joint_kws = {} if joint_kws is None else joint_kws.copy()
    joint_kws.update(kwargs)
    marginal_kws = {} if marginal_kws is None else marginal_kws.copy()

    # Handle deprecations of distplot-specific kwargs: strip them so they are
    # not forwarded to the new marginal functions, and tell the user.
    distplot_keys = [
        "rug", "fit", "hist_kws", "norm_hist", "rug_kws",
    ]
    unused_keys = []
    for key in distplot_keys:
        if key in marginal_kws:
            unused_keys.append(key)
            marginal_kws.pop(key)
    if unused_keys and kind != "kde":
        msg = (
            "The marginal plotting function has changed to `histplot`,"
            " which does not accept the following argument(s): {}."
        ).format(", ".join(unused_keys))
        warnings.warn(msg, UserWarning, stacklevel=2)

    # Validate the plot kind
    plot_kinds = ["scatter", "hist", "hex", "kde", "reg", "resid"]
    _check_argument("kind", plot_kinds, kind)

    # Raise early if using `hue` with a kind that does not support it
    if hue is not None and kind in ["hex", "reg", "resid"]:
        msg = (
            f"Use of `hue` with `kind='{kind}'` is not currently supported."
        )
        raise ValueError(msg)

    # Make a colormap based off the plot color
    # (Currently used only for kind="hex")
    if color is None:
        color = "C0"
    color_rgb = mpl.colors.colorConverter.to_rgb(color)
    colors = [utils.set_hls_values(color_rgb, l=l)  # noqa
              for l in np.linspace(1, 0, 12)]
    cmap = blend_palette(colors, as_cmap=True)

    # Matplotlib's hexbin plot is not na-robust
    if kind == "hex":
        dropna = True

    # Initialize the JointGrid object
    grid = JointGrid(
        data=data, x=x, y=y, hue=hue,
        palette=palette, hue_order=hue_order, hue_norm=hue_norm,
        dropna=dropna, height=height, ratio=ratio, space=space,
        xlim=xlim, ylim=ylim, marginal_ticks=marginal_ticks,
    )

    if grid.hue is not None:
        marginal_kws.setdefault("legend", False)

    # Plot the data using the grid
    if kind.startswith("scatter"):

        joint_kws.setdefault("color", color)
        grid.plot_joint(scatterplot, **joint_kws)

        if grid.hue is None:
            marg_func = histplot
        else:
            marg_func = kdeplot
            marginal_kws.setdefault("warn_singular", False)
            marginal_kws.setdefault("fill", True)

        marginal_kws.setdefault("color", color)
        grid.plot_marginals(marg_func, **marginal_kws)

    elif kind.startswith("hist"):

        # TODO process pair parameters for bins, etc. and pass
        # to both joint and marginal plots

        joint_kws.setdefault("color", color)
        grid.plot_joint(histplot, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)
        if grid.hue is not None:
            # Use the same hue mapping as the joint plot so colors agree
            marginal_kws.setdefault("palette", palette)
            marginal_kws.setdefault("hue_order", hue_order)
            marginal_kws.setdefault("hue_norm", hue_norm)

        marg_x_kws = marginal_kws.copy()
        marg_y_kws = marginal_kws.copy()

        # A (x_value, y_value) tuple in joint_kws is split across the marginals
        pair_keys = "bins", "binwidth", "binrange"
        for key in pair_keys:
            if isinstance(joint_kws.get(key), tuple):
                x_val, y_val = joint_kws[key]
                marg_x_kws.setdefault(key, x_val)
                marg_y_kws.setdefault(key, y_val)

        # Draw from the grid's vectors so ``dropna`` applies to the marginals
        # exactly as it does to the joint plot
        histplot(x=grid.x, hue=grid.hue, **marg_x_kws, ax=grid.ax_marg_x)
        histplot(y=grid.y, hue=grid.hue, **marg_y_kws, ax=grid.ax_marg_y)

    elif kind.startswith("kde"):

        joint_kws.setdefault("color", color)
        joint_kws.setdefault("warn_singular", False)
        grid.plot_joint(kdeplot, **joint_kws)

        marginal_kws.setdefault("color", color)
        if "fill" in joint_kws:
            marginal_kws.setdefault("fill", joint_kws["fill"])

        grid.plot_marginals(kdeplot, **marginal_kws)

    elif kind.startswith("hex"):

        # Pick a hexbin gridsize from the Freedman-Diaconis bin counts, capped at 50
        x_bins = min(_freedman_diaconis_bins(grid.x), 50)
        y_bins = min(_freedman_diaconis_bins(grid.y), 50)
        gridsize = int(np.mean([x_bins, y_bins]))

        joint_kws.setdefault("gridsize", gridsize)
        joint_kws.setdefault("cmap", cmap)
        grid.plot_joint(plt.hexbin, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(histplot, **marginal_kws)

    elif kind.startswith("reg"):

        marginal_kws.setdefault("color", color)
        marginal_kws.setdefault("kde", True)
        grid.plot_marginals(histplot, **marginal_kws)

        joint_kws.setdefault("color", color)
        grid.plot_joint(regplot, **joint_kws)

    elif kind.startswith("resid"):

        joint_kws.setdefault("color", color)
        grid.plot_joint(residplot, **joint_kws)

        # Marginals show the plotted (x, residual) points from the joint Axes
        x, y = grid.ax_joint.collections[0].get_offsets().T
        marginal_kws.setdefault("color", color)
        histplot(x=x, hue=hue, ax=grid.ax_marg_x, **marginal_kws)
        histplot(y=y, hue=hue, ax=grid.ax_marg_y, **marginal_kws)

    # Make the main axes active in the matplotlib state machine
    plt.sca(grid.ax_joint)

    return grid


# The docstring is attached after definition so the shared parameter, return,
# and see-also snippets can be interpolated with str.format (literal braces
# are therefore escaped as ``{{...}}``).
jointplot.__doc__ = """\
Draw a plot of two variables with bivariate and univariate graphs.

This function provides a convenient interface to the :class:`JointGrid`
class, with several canned plot kinds. This is intended to be a fairly
lightweight wrapper; if you need more flexibility, you should use
:class:`JointGrid` directly.

Parameters
----------
{params.core.data}
{params.core.xy}
{params.core.hue}
    Semantic variable that is mapped to determine the color of plot elements.
kind : {{ "scatter" | "kde" | "hist" | "hex" | "reg" | "resid" }}
    Kind of plot to draw. See the examples for references to the underlying functions.
height : numeric
    Size of the figure in inches (it will be square).
ratio : numeric
    Ratio of joint axes height to marginal axes height.
space : numeric
    Space between the joint and marginal axes.
dropna : bool
    If True, remove observations that are missing from ``x`` and ``y``.
{{x, y}}lim : pairs of numbers
    Axis limits to set before plotting.
{params.core.color}
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
marginal_ticks : bool
    If False, suppress ticks on the count/density axis of the marginal plots.
{{joint, marginal}}_kws : dicts
    Additional keyword arguments for the plot components.
kwargs
    Additional keyword arguments are passed to the function used to
    draw the plot on the joint Axes, superseding items in the
    ``joint_kws`` dictionary.

Returns
-------
{returns.jointgrid}

See Also
--------
{seealso.jointgrid}
{seealso.pairgrid}
{seealso.pairplot}

Examples
--------

.. include:: ../docstrings/jointplot.rst

""".format(
    params=_param_docs,
    returns=_core_docs["returns"],
    seealso=_core_docs["seealso"],
)
