# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains testing utilities related to mixed precision."""

import tensorflow.compat.v2 as tf

from keras import regularizers
from keras.engine import base_layer


def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
    """Builds an identity function whose gradient is checked against a value.

    The returned function can be inserted into a model as a hook for verifying
    an intermediate gradient. Its forward pass is `tf.identity`. Its gradient
    function passes the incoming gradient through unchanged, after asserting
    that the gradient equals `expected_gradient` and, when `expected_dtype` is
    given, that the gradient has dtype `expected_dtype`.

    Args:
      expected_gradient: The value the incoming gradient is asserted to equal.
        It is converted to a tensor with the incoming gradient's dtype.
      expected_dtype: The dtype the incoming gradient is asserted to have. The
        dtype check is skipped when this is falsy (e.g. the default `None`).

    Returns:
      An identity function whose gradient function asserts the gradient has a
      certain value.
    """

    @tf.custom_gradient
    def _identity_with_grad_check(x):
        """Identity whose gradient function validates the incoming gradient."""
        forward_output = tf.identity(x)

        def grad(dx):
            """Checks `dx` against the expected dtype and value, then returns
            it."""
            if expected_dtype:
                dtype_error = "dx.dtype should be %s but is: %s" % (
                    expected_dtype,
                    dx.dtype,
                )
                assert dx.dtype == expected_dtype, dtype_error
            expected = tf.convert_to_tensor(
                expected_gradient, dtype=dx.dtype, name="expected_gradient"
            )
            # The value check is made to depend on the forward output so that
            # it only runs once the input is actually available. A dataset may
            # raise StopIteration to signal it is exhausted, and in that case
            # the assertion must not run.
            with tf.control_dependencies([forward_output]):
                value_check = tf.compat.v1.assert_equal(dx, expected)
            # Route the gradient through an identity op that depends on the
            # check, so the check is executed whenever the gradient is used.
            with tf.control_dependencies([value_check]):
                return tf.identity(dx)

        return forward_output, grad

    # Keras can have trouble serializing a Lambda layer that wraps a decorated
    # function, so a plain undecorated wrapper is what gets returned.
    def identity_with_grad_check(x):
        return _identity_with_grad_check(x)

    return identity_with_grad_check


def create_identity_with_nan_gradients_fn(have_nan_gradients):
    """Builds an identity function whose gradient can be switched to NaN.

    The returned function can be inserted into a model as a hook for
    introducing NaN gradients. Its forward pass is `tf.identity`. Its gradient
    function selects, via `tf.cond` on `have_nan_gradients`, between the
    incoming gradient multiplied by NaN and the incoming gradient unchanged.

    Args:
      have_nan_gradients: A scalar boolean tensor. If True, gradients will be
        NaN. Otherwise, the gradient function is the identity function.

    Returns:
      An identity function whose gradient function will return NaNs, if
      `have_nan_gradients` is True.
    """

    @tf.custom_gradient
    def _identity_with_nan_gradients(x):
        """Identity whose gradient is NaN iff `have_nan_gradients` is True."""
        forward_output = tf.identity(x)

        def grad(dx):
            """Returns `dx * NaN` or `dx`, chosen by `have_nan_gradients`."""

            def nan_gradient():
                return dx * float("NaN")

            def unchanged_gradient():
                return dx

            return tf.cond(
                have_nan_gradients,
                true_fn=nan_gradient,
                false_fn=unchanged_gradient,
            )

        return forward_output, grad

    # Keras can have trouble serializing a Lambda layer that wraps a decorated
    # function, so a plain undecorated wrapper is what gets returned.
    def identity_with_nan_gradients(x):
        return _identity_with_nan_gradients(x)

    return identity_with_nan_gradients


class AssertTypeLayer(base_layer.Layer):
    """A layer which asserts its inputs are a certain type."""

    def __init__(self, assert_type=None, **kwargs):
        """Initializes the AssertTypeLayer.

        Args:
          assert_type: The dtype inputs are expected to have. It is converted
            with `tf.as_dtype` and stored as that dtype's `name`. If falsy, no
            type is recorded and `assert_input_types` checks nothing.
          **kwargs: Passed to the base Layer constructor.
        """
        if assert_type:
            self._assert_type = tf.as_dtype(assert_type).name
        else:
            self._assert_type = None
        super().__init__(**kwargs)

    def assert_input_types(self, inputs):
        """Asserts every tensor in `inputs` has the recorded dtype.

        Should be called in call(). `inputs` is flattened with
        `tf.nest.flatten`, and each element's base dtype is compared against
        the recorded type.

        Args:
          inputs: A possibly nested structure of tensors to check.
        """
        if not self._assert_type:
            return
        for tensor in tf.nest.flatten(inputs):
            assert tensor.dtype.base_dtype == self._assert_type, (
                "Input tensor has type %s which does "
                "not match assert type %s"
                % (tensor.dtype.name, self._assert_type)
            )


class MultiplyLayer(AssertTypeLayer):
    """A layer which multiplies its input by a scalar variable."""

    def __init__(
        self,
        regularizer=None,
        activity_regularizer=None,
        use_operator=False,
        var_name="v",
        **kwargs
    ):
        """Initializes the MultiplyLayer.

        Args:
          regularizer: The weight regularizer on the scalar variable. A dict is
            treated as a serialized regularizer and deserialized, with this
            module's globals available as custom objects.
          activity_regularizer: The activity regularizer. A dict is handled the
            same way as for `regularizer`.
          use_operator: If True, multiply using the * operator. If False,
            multiply using tf.multiply.
          var_name: The name of the variable. It can be useful to pass a name
            other than 'v', to test having the attribute name (self.v) being
            different from the variable name.
          **kwargs: Passed to AssertTypeLayer constructor.
        """
        self._regularizer = regularizer
        if isinstance(regularizer, dict):
            self._regularizer = regularizers.deserialize(
                regularizer, custom_objects=globals()
            )

        self._activity_regularizer = activity_regularizer
        if isinstance(activity_regularizer, dict):
            self._activity_regularizer = regularizers.deserialize(
                activity_regularizer, custom_objects=globals()
            )

        self._use_operator = use_operator
        self._var_name = var_name
        super().__init__(
            activity_regularizer=self._activity_regularizer, **kwargs
        )

    def build(self, _):
        """Creates the scalar variable `self.v`, initialized to ones."""
        self.v = self.add_weight(
            name=self._var_name,
            shape=(),
            initializer="ones",
            regularizer=self._regularizer,
        )
        self.built = True

    def call(self, inputs):
        """Checks the input types, then returns `inputs` times `self.v`."""
        self.assert_input_types(inputs)
        product = self._multiply(inputs, self.v)
        return product

    def _multiply(self, x, y):
        """Multiplies `x` by `y` with `*` or `tf.multiply`, per
        `use_operator`."""
        return x * y if self._use_operator else tf.multiply(x, y)

    def get_config(self):
        """Returns the base config extended with this layer's constructor
        arguments."""
        config = super().get_config()
        config.update(
            {
                "regularizer": regularizers.serialize(self._regularizer),
                "activity_regularizer": regularizers.serialize(
                    self._activity_regularizer
                ),
                "use_operator": self._use_operator,
                "var_name": self._var_name,
                "assert_type": self._assert_type,
            }
        )
        return config


class MultiplyLayerWithoutAutoCast(MultiplyLayer):
    """Same as MultiplyLayer, but does not use AutoCastVariables.

    The scalar variable is created with `experimental_autocast=False`, and
    `call` casts it to the input dtype explicitly before multiplying.
    """

    def build(self, _):
        """Creates the scalar variable `self.v` with autocasting disabled.

        If the layer's dtype is "float16" or "bfloat16", the variable is
        created as "float32" instead; any other dtype is used as is.
        """
        dtype = self.dtype
        if dtype in ("float16", "bfloat16"):
            dtype = "float32"
        self.v = self.add_weight(
            # Honor the `var_name` constructor argument, as MultiplyLayer.build
            # does, so the variable name agrees with what get_config() reports.
            # The default is still "v".
            self._var_name,
            (),
            initializer="ones",
            dtype=dtype,
            experimental_autocast=False,
            regularizer=self._regularizer,
        )
        self.built = True

    def call(self, inputs):
        """Checks the input types, then multiplies `inputs` by the variable.

        Asserts that the variable's dtype is float32 or float64, then casts
        the variable to `inputs.dtype` before multiplying.
        """
        self.assert_input_types(inputs)
        assert self.v.dtype in (tf.float32, tf.float64)
        return self._multiply(inputs, tf.cast(self.v, inputs.dtype))


class IdentityRegularizer(regularizers.Regularizer):
    """A regularizer that asserts its input is float32 and returns it as is."""

    def __call__(self, x):
        """Asserts `x` has dtype float32 and returns `tf.identity(x)`."""
        is_float32 = x.dtype == tf.float32
        assert is_float32
        return tf.identity(x)

    def get_config(self):
        """Returns an empty config dict."""
        config = {}
        return config


class ReduceSumRegularizer(regularizers.Regularizer):
    """A regularizer that returns the sum of all elements of its input."""

    def __call__(self, x):
        """Returns `tf.reduce_sum(x)`."""
        total = tf.reduce_sum(x)
        return total

    def get_config(self):
        """Returns an empty config dict."""
        config = {}
        return config
