# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""Contains the InputSpec class."""

import tensorflow.compat.v2 as tf

from keras import backend

# isort: off
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export


@keras_export(
    "keras.layers.InputSpec",
    v1=["keras.layers.InputSpec", "keras.__internal__.legacy.layers.InputSpec"],
)
@tf_export(v1=["layers.InputSpec"])
class InputSpec:
    """Specifies the rank, dtype and shape of every input to a layer.

    Layers can expose (if appropriate) an `input_spec` attribute:
    an instance of `InputSpec`, or a nested structure of `InputSpec` instances
    (one per input tensor). These objects enable the layer to run input
    compatibility checks for input structure, input rank, input shape, and
    input dtype.

    A None entry in a shape is compatible with any dimension,
    a None shape is compatible with any shape.

    Args:
      dtype: Expected DataType of the input.
      shape: Shape tuple, expected shape of the input
        (may include None for unchecked axes). Includes the batch size.
        When the shape has a known rank, it determines `ndim`.
      ndim: Integer, expected rank of the input. Ignored when `shape` has a
        known rank.
      max_ndim: Integer, maximum rank of the input.
      min_ndim: Integer, minimum rank of the input.
      axes: Dictionary mapping integer axes to
        a specific dimension value. Negative keys index from the last axis.
      allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
        as the last axis of the input is 1, as well as inputs of rank N-1
        as long as the last axis of the spec is 1.
      name: Expected key corresponding to this input when passing data as
        a dictionary.

    Example:

    ```python
    class MyLayer(Layer):
        def __init__(self):
            super(MyLayer, self).__init__()
            # The layer will accept inputs with
            # shape (?, 28, 28) & (?, 28, 28, 1)
            # and raise an appropriate error message otherwise.
            self.input_spec = InputSpec(
                shape=(None, 28, 28, 1),
                allow_last_axis_squeeze=True)
    ```
    """

    def __init__(
        self,
        dtype=None,
        shape=None,
        ndim=None,
        max_ndim=None,
        min_ndim=None,
        axes=None,
        allow_last_axis_squeeze=False,
        name=None,
    ):
        # Store the canonical dtype name so it can be compared directly with
        # `tensor.dtype.name` during compatibility checks.
        self.dtype = tf.as_dtype(dtype).name if dtype is not None else None
        shape = tf.TensorShape(shape)
        if shape.rank is None:
            # Unknown rank: no shape constraint, rank comes from `ndim`.
            self.shape = None
            self.ndim = ndim
        else:
            # A known-rank shape fixes the rank and overrides `ndim`.
            self.shape = tuple(shape.as_list())
            self.ndim = len(self.shape)
        self.max_ndim = max_ndim
        self.min_ndim = min_ndim
        self.name = name
        self.allow_last_axis_squeeze = allow_last_axis_squeeze
        try:
            axes = axes or {}
            self.axes = {int(k): axes[k] for k in axes}
        except (ValueError, TypeError) as err:
            raise TypeError(
                "Argument `axes` must be a dict with integer keys. "
                f"Received: axes={axes}"
            ) from err

        # Make sure every pinned axis can exist for the declared rank.
        if self.axes and (self.ndim is not None or self.max_ndim is not None):
            # Use `is not None` rather than truthiness so that `ndim=0` is
            # honored instead of falling back to a possibly-None `max_ndim`.
            rank = self.ndim if self.ndim is not None else self.max_ndim
            max_dim = rank - 1
            max_axis = max(self.axes)
            if max_axis > max_dim:
                raise ValueError(
                    "Axis {} is greater than the maximum "
                    "allowed value: {}".format(max_axis, max_dim)
                )
            # Negative axes count from the end and cannot go below -rank.
            min_dim = -rank
            min_axis = min(self.axes)
            if min_axis < min_dim:
                raise ValueError(
                    f"Axis {min_axis} is less than the minimum "
                    f"allowed value: {min_dim}"
                )

    def __repr__(self):
        # Compare against None (not truthiness) so legitimate zero values
        # such as `ndim=0` or `shape=()` are still displayed.
        spec = [
            ("dtype=" + str(self.dtype)) if self.dtype else "",
            ("shape=" + str(self.shape)) if self.shape is not None else "",
            ("ndim=" + str(self.ndim)) if self.ndim is not None else "",
            ("max_ndim=" + str(self.max_ndim))
            if self.max_ndim is not None
            else "",
            ("min_ndim=" + str(self.min_ndim))
            if self.min_ndim is not None
            else "",
            ("axes=" + str(self.axes)) if self.axes else "",
        ]
        return "InputSpec(%s)" % ", ".join(x for x in spec if x)

    def get_config(self):
        """Returns the constructor arguments needed to rebuild this spec.

        Every `__init__` argument is included so that
        `InputSpec.from_config(spec.get_config())` round-trips, including
        `allow_last_axis_squeeze` and `name`.
        """
        return {
            "dtype": self.dtype,
            "shape": self.shape,
            "ndim": self.ndim,
            "max_ndim": self.max_ndim,
            "min_ndim": self.min_ndim,
            "axes": self.axes,
            "allow_last_axis_squeeze": self.allow_last_axis_squeeze,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        """Creates an `InputSpec` from a dict produced by `get_config`."""
        return cls(**config)


def to_tensor_shape(spec):
    """Builds a `tf.TensorShape` from the shape constraints of an InputSpec.

    An explicit `spec.shape` wins. Otherwise, when `spec.ndim` is set, the
    result has that rank with every dimension unknown except the ones pinned
    by `spec.axes`. When neither is set, the result is an unknown-rank
    `tf.TensorShape` (not `None`).

    Args:
      spec: an InputSpec object.

    Returns:
      a tf.TensorShape object
    """
    if spec.shape is not None:
        return tf.TensorShape(spec.shape)
    if spec.ndim is None:
        # Only rank bounds (or nothing) are known, so the rank is unknown.
        return tf.TensorShape(None)
    dims = [None] * spec.ndim
    # `axes` is always a dict on an InputSpec; negative keys index from
    # the end of `dims`.
    for axis, size in spec.axes.items():
        dims[axis] = size
    return tf.TensorShape(dims)


def assert_input_compatibility(input_spec, inputs, layer_name):
    """Checks compatibility between the layer and provided inputs.

    This checks that the tensor(s) `inputs` verify the input assumptions
    of a layer (if any). If not, a clear and actionable exception gets raised.

    Args:
        input_spec: An InputSpec instance, list of InputSpec instances, a nested
            structure of InputSpec instances, or None.
        inputs: Input tensor, list of input tensors, or a nested structure of
            input tensors. When `inputs` is a dict and every spec has a
            `name`, the dict values are matched to the specs by name.
        layer_name: String, name of the layer (for error message formatting).

    Raises:
        TypeError: if a flattened input has no `shape` attribute.
        ValueError: in case of mismatch between
            the provided inputs and the expectations of the layer.
    """
    if not input_spec:
        return

    input_spec = tf.nest.flatten(input_spec)
    if isinstance(inputs, dict):
        # Flatten `inputs` by reference order if input spec names are provided
        names = [spec.name for spec in input_spec]
        if all(names):
            list_inputs = []
            for name in names:
                if name not in inputs:
                    raise ValueError(
                        f'Missing data for input "{name}". '
                        "You passed a data dictionary with keys "
                        f"{list(inputs.keys())}. "
                        f"Expected the following keys: {names}"
                    )
                list_inputs.append(inputs[name])
            inputs = list_inputs

    inputs = tf.nest.flatten(inputs)
    for x in inputs:
        # Having a shape/dtype is the only commonality of the various
        # tensor-like objects that may be passed. The most common kind of
        # invalid type we are guarding for is a Layer instance (Functional API),
        # which does not have a `shape` attribute.
        if not hasattr(x, "shape"):
            raise TypeError(f"Inputs to a layer should be tensors. Got: {x}")

    if len(inputs) != len(input_spec):
        raise ValueError(
            f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
            f" but it received {len(inputs)} input tensors. "
            f"Inputs received: {inputs}"
        )
    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
        if spec is None:
            continue

        # Normalize once; `x.shape` may be a TensorShape or a plain tuple
        # (e.g. for NumPy arrays), which has no `.rank` / `.as_list()`.
        shape = tf.TensorShape(x.shape)
        if shape.rank is None:
            # Nothing rank- or shape-related can be checked for this input
            # (its dtype is not checked either). Skip it, but keep validating
            # the remaining inputs instead of returning early.
            continue
        ndim = shape.rank

        # Check ndim.
        if spec.ndim is not None and not spec.allow_last_axis_squeeze:
            if ndim != spec.ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected ndim={spec.ndim}, found ndim={ndim}. "
                    f"Full shape received: {tuple(shape)}"
                )
        if spec.max_ndim is not None and ndim > spec.max_ndim:
            raise ValueError(
                f'Input {input_index} of layer "{layer_name}" '
                "is incompatible with the layer: "
                f"expected max_ndim={spec.max_ndim}, "
                f"found ndim={ndim}"
            )
        if spec.min_ndim is not None and ndim < spec.min_ndim:
            raise ValueError(
                f'Input {input_index} of layer "{layer_name}" '
                "is incompatible with the layer: "
                f"expected min_ndim={spec.min_ndim}, "
                f"found ndim={ndim}. "
                f"Full shape received: {tuple(shape)}"
            )
        # Check dtype.
        if spec.dtype is not None:
            if x.dtype.name != spec.dtype:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected dtype={spec.dtype}, "
                    f"found dtype={x.dtype}"
                )

        # Check specific shape axes.
        shape_as_list = shape.as_list()
        if spec.axes:
            for axis, value in spec.axes.items():
                if hasattr(value, "value"):
                    # Unwrap dimension-like objects exposing `.value`
                    # (presumably TF1-style `Dimension`).
                    value = value.value
                if value is None:
                    continue
                axis_index = int(axis)
                # An axis the input does not have is a mismatch; report it
                # as a ValueError instead of letting indexing raise.
                axis_missing = not -ndim <= axis_index < ndim
                if axis_missing or shape_as_list[axis_index] not in {
                    value,
                    None,
                }:
                    raise ValueError(
                        f'Input {input_index} of layer "{layer_name}" is '
                        f"incompatible with the layer: expected axis {axis} "
                        f"of input shape to have value {value}, "
                        "but received input with "
                        f"shape {display_shape(shape)}"
                    )
        # Check shape.
        if spec.shape is not None:
            spec_shape = spec.shape
            if spec.allow_last_axis_squeeze:
                if shape_as_list and shape_as_list[-1] == 1:
                    shape_as_list = shape_as_list[:-1]
                if spec_shape and spec_shape[-1] == 1:
                    spec_shape = spec_shape[:-1]
            # NOTE(review): `zip` stops at the shorter of the two shapes, so
            # when `allow_last_axis_squeeze` is set (and the ndim check above
            # is skipped) a larger rank difference is not detected here.
            for spec_dim, dim in zip(spec_shape, shape_as_list):
                if spec_dim is not None and dim is not None:
                    if spec_dim != dim:
                        raise ValueError(
                            f'Input {input_index} of layer "{layer_name}" is '
                            "incompatible with the layer: "
                            f"expected shape={spec.shape}, "
                            f"found shape={display_shape(shape)}"
                        )


def display_shape(shape):
    """Formats a shape as a tuple string for use in error messages.

    Args:
      shape: A `tf.TensorShape` (or any object exposing `as_list()`), or a
        plain sequence of dimensions such as the tuple `shape` of a NumPy
        array.

    Returns:
      A string such as `"(None, 28, 28)"`. A shape whose `rank` is None is
      rendered with `str(shape)` instead, because `as_list()` is undefined
      for it.
    """
    as_list = getattr(shape, "as_list", None)
    if as_list is None:
        # Plain sequences (tuples/lists) have no `as_list()`.
        return str(tuple(shape))
    if getattr(shape, "rank", 0) is None:
        # This helper only formats error messages, so it must not raise on
        # an unknown-rank shape.
        return str(shape)
    return str(tuple(as_list()))


def to_tensor_spec(input_spec, default_dtype=None):
    """Builds a `tf.TensorSpec` describing inputs accepted by `input_spec`.

    Args:
      input_spec: An `InputSpec`. Any other value yields a spec whose shape
        is fully unknown.
      default_dtype: dtype to use when `input_spec` does not pin one. When
        falsy, `backend.floatx()` is used.

    Returns:
      A `tf.TensorSpec`.
    """
    fallback_dtype = default_dtype if default_dtype else backend.floatx()
    if not isinstance(input_spec, InputSpec):
        return tf.TensorSpec(None, fallback_dtype)
    spec_dtype = input_spec.dtype if input_spec.dtype else fallback_dtype
    return tf.TensorSpec(to_tensor_shape(input_spec), spec_dtype)
