import numpy as np
import matplotlib as mpl
from seaborn.external.version import Version


def MarkerStyle(marker=None, fillstyle=None):
    """
    Construct a matplotlib MarkerStyle, also accepting a MarkerStyle as input.

    This shim exists for matplotlib < 3.3.0
    (see https://github.com/matplotlib/matplotlib/pull/16692).

    Behavior:
    - An existing MarkerStyle with no fillstyle requested is returned as-is.
    - An existing MarkerStyle with a fillstyle has its underlying marker
      extracted and re-wrapped with that fillstyle.
    - Any other value is passed straight to the MarkerStyle constructor.

    """
    is_style_obj = isinstance(marker, mpl.markers.MarkerStyle)

    if is_style_obj and fillstyle is None:
        # Nothing to change: hand back the caller's object untouched
        return marker

    if is_style_obj:
        # Unwrap to the raw marker spec so it can be combined with fillstyle
        marker = marker.get_marker()

    return mpl.markers.MarkerStyle(marker, fillstyle)


def norm_from_scale(scale, norm):
    """
    Produce a Normalize object given a Scale and min/max domain limits.

    Parameters
    ----------
    scale : matplotlib scale object or None
        Its ``get_transform().transform`` is applied to both the data and the
        limits before they are linearly rescaled.
    norm : ``mpl.colors.Normalize``, (vmin, vmax) pair, or None
        An existing Normalize is returned unchanged. A pair sets the limits.
        None leaves the limits to be autoscaled from the first data passed in.

    Returns
    -------
    ``mpl.colors.Normalize`` or None
        Returns ``norm`` itself if it is already a Normalize, and None if
        ``scale`` is None. Otherwise returns a Normalize that rescales values
        linearly in the scale's transformed space.

    Raises
    ------
    TypeError, ValueError
        If ``norm`` is not a Normalize or None and cannot be unpacked into
        exactly two values. The exception class matches what plain tuple
        unpacking would raise.

    """
    # This reproduces behavior matplotlib has internally but does not expose
    # publicly. It is likely to become part of the matplotlib API at some
    # point: https://github.com/matplotlib/matplotlib/issues/20329
    if isinstance(norm, mpl.colors.Normalize):
        return norm

    if scale is None:
        return None

    if norm is None:
        vmin = vmax = None
    else:
        msg = (
            "`norm` must be None, a matplotlib Normalize object, "
            f"or a (vmin, vmax) pair; got {norm!r}."
        )
        # Re-raise with the same exception class that bare unpacking would
        # produce, so any existing handlers keep matching.
        try:
            vmin, vmax = norm
        except TypeError as err:
            raise TypeError(msg) from err
        except ValueError as err:
            raise ValueError(msg) from err

    class ScaledNorm(mpl.colors.Normalize):

        def __call__(self, value, clip=None):
            # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py
            # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE
            value, is_scalar = self.process_value(value)
            self.autoscale_None(value)
            if self.vmin > self.vmax:
                raise ValueError("vmin must be less or equal to vmax")
            if self.vmin == self.vmax:
                return np.full_like(value, 0)
            if clip is None:
                clip = self.clip
            if clip:
                value = np.clip(value, self.vmin, self.vmax)
            # ***** Seaborn changes start ****
            # `self.transform` is attached to the instance after construction
            t_value = self.transform(value).reshape(np.shape(value))
            t_vmin, t_vmax = self.transform([self.vmin, self.vmax])
            # ***** Seaborn changes end *****
            if not np.isfinite([t_vmin, t_vmax]).all():
                raise ValueError("Invalid vmin or vmax")
            t_value -= t_vmin
            t_value /= (t_vmax - t_vmin)
            t_value = np.ma.masked_invalid(t_value, copy=False)
            return t_value[0] if is_scalar else t_value

    new_norm = ScaledNorm(vmin, vmax)
    new_norm.transform = scale.get_transform().transform

    return new_norm


def scale_factory(scale, axis, **kwargs):
    """
    Backwards compatibility for creation of independent scales.

    Matplotlib scales require an Axis object for instantiation on < 3.4.
    But the axis is not used, aside from extraction of the axis_name in LogScale.

    Parameters
    ----------
    scale : str
        Typically the name of a registered matplotlib scale. It is passed to
        ``mpl.scale.scale_factory``.
    axis : str
        Name of the axis the scale is for. On matplotlib < 3.4, only its first
        character is used, and only when that character is "x" or "y".
    **kwargs
        Scale keyword arguments in the modern (>= 3.4) spelling, e.g.
        ``base`` and ``nonpositive``. On older matplotlib, these are renamed
        to their axis-suffixed equivalents (``basex``, ``nonposy``, ...).

    Returns
    -------
    The scale object created by ``mpl.scale.scale_factory``.

    """
    modify_transform = False
    # Hold the popped legacy values so the post-hoc transform fix-up below can
    # read them. They are removed from kwargs, so kwargs no longer has them.
    base = nonpos = None
    if Version(mpl.__version__) < Version("3.4"):
        if axis[0] in "xy":
            modify_transform = True
            axis = axis[0]
            base = kwargs.pop("base", None)
            if base is not None:
                kwargs[f"base{axis}"] = base
            nonpos = kwargs.pop("nonpositive", None)
            if nonpos is not None:
                kwargs[f"nonpos{axis}"] = nonpos

    if isinstance(scale, str):
        # Minimal stand-in for an Axis object. Per the docstring, only
        # axis_name is consulted.
        class Axis:
            axis_name = axis
        axis = Axis()

    scale = mpl.scale.scale_factory(scale, axis, **kwargs)

    if modify_transform:
        transform = scale.get_transform()
        # Use the popped values. Reading "base"/"nonpositive" back from kwargs
        # would always give the defaults and overwrite a user-supplied base.
        transform.base = 10 if base is None else base
        if nonpos == "mask":
            # Setting a private attribute, but we only get here
            # on an old matplotlib, so this won't break going forwards
            transform._clip = False

    return scale


def set_scale_obj(ax, axis, scale):
    """
    Apply a matplotlib scale object to one axis of ``ax`` across versions.

    On matplotlib >= 3.4, the scale instance is handed directly to ``Axes.set``.
    On older versions, the scale is applied by *name*. Its default
    locators/formatters are then installed on the matching Axis by hand.

    Parameters
    ----------
    ax : matplotlib Axes
    axis : str
        Axis letter (presumably "x" or "y"), used to build the
        ``set_{axis}scale`` and ``{axis}axis`` attribute names.
    scale : matplotlib scale object

    """
    if Version(mpl.__version__) >= Version("3.4"):
        ax.set(**{f"{axis}scale": scale})
        return

    # Passing a BaseScale instance to Axes.set_{}scale was only added in
    # matplotlib 3.4.0 (GH: matplotlib/matplotlib/pull/19089). Using the scale
    # name instead only restricts users with custom scales; they would also
    # have to register those scales.
    if scale.name is None:
        # Hack to support our custom Formatter-less CatScale
        return

    set_axis_scale = getattr(ax, f"set_{axis}scale")
    scale_kws = {}
    if scale.name == "function":
        # A "function" scale has to be rebuilt from its forward/inverse pair
        transform = scale.get_transform()
        scale_kws["functions"] = transform._forward, transform._inverse
    set_axis_scale(scale.name, **scale_kws)
    scale.set_default_locators_and_formatters(getattr(ax, f"{axis}axis"))


def get_colormap(name):
    """
    Look up a colormap by name, handling the 3.6 colormap interface change.

    Tries the ``mpl.colormaps`` registry first. If accessing it raises
    AttributeError (older matplotlib), falls back to ``mpl.cm.get_cmap``.
    """
    try:
        cmap = mpl.colormaps[name]
    except AttributeError:
        # No top-level colormap registry on this matplotlib version
        cmap = mpl.cm.get_cmap(name)
    return cmap


def register_colormap(name, cmap):
    """
    Register ``cmap`` under ``name``, handling the 3.6 colormap interface change.

    With the ``mpl.colormaps`` registry, the colormap is added only if ``name``
    is not already present. If the registry or its ``register`` method is
    missing (AttributeError), falls back to ``mpl.cm.register_cmap``.
    """
    try:
        registry = mpl.colormaps
        if name in registry:
            # Already available; leave the existing entry alone
            return
        registry.register(cmap, name=name)
    except AttributeError:
        mpl.cm.register_cmap(name, cmap)


def set_layout_engine(fig, engine):
    """
    Handle changes to auto layout engine interface in 3.6.

    Parameters
    ----------
    fig : matplotlib Figure
    engine : str or None
        If ``fig`` provides ``set_layout_engine``, the value is passed through
        unchanged. Otherwise:

        - "tight" enables tight layout.
        - "constrained" enables constrained layout.
        - "none" disables both.
        - Any other value is ignored.

    """
    if hasattr(fig, "set_layout_engine"):
        fig.set_layout_engine(engine)
        return

    # Older figures only expose a boolean toggle per layout mechanism
    if engine == "tight":
        fig.set_tight_layout(True)
    elif engine == "constrained":
        fig.set_constrained_layout(True)
    elif engine == "none":
        # Mirror the modern "none" engine (removes any layout engine), which
        # was previously silently ignored on this code path.
        fig.set_tight_layout(False)
        fig.set_constrained_layout(False)


def share_axis(ax0, ax1, which):
    """
    Make ``ax1`` share its ``which`` ("x" or "y") axis with ``ax0`` after creation.

    matplotlib >= 3.5.0 uses ``Axes.share{which}``. Older versions instead
    join both Axes in ``ax0``'s shared-axes grouper.
    """
    if Version(mpl.__version__) >= Version("3.5.0"):
        share_with = getattr(ax1, f"share{which}")
        share_with(ax0)
    else:
        grouper = getattr(ax0, f"get_shared_{which}_axes")()
        grouper.join(ax1, ax0)
