# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains testing utilities related to mixed precision."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
  """Returns a function that asserts its gradient has a certain value.

  The returned callable is an identity function intended to be inserted into a
  model as a hook for checking intermediate gradients. Its gradient function is
  also the identity, except that it asserts the incoming gradient equals
  `expected_gradient` and, when `expected_dtype` is given, that the gradient
  has dtype `expected_dtype`.

  Args:
    expected_gradient: The value the incoming gradient is asserted to equal. It
      is converted to a tensor with the same dtype as the incoming gradient.
    expected_dtype: If truthy, the dtype the incoming gradient is asserted to
      have. If falsy (the default is None), the dtype is not checked.

  Returns:
    An identity function whose gradient function asserts the gradient has a
    certain value.
  """
  @custom_gradient.custom_gradient
  def _identity_with_grad_check(x):
    """Identity function whose gradient function validates the gradient."""
    forward_out = array_ops.identity(x)

    def grad(dx):
      """Checks `dx` against the expected dtype and value, then returns it."""
      if expected_dtype:
        # The dtype check is a plain Python assert, so it fires when the
        # gradient function is called rather than when an op executes.
        assert dx.dtype == expected_dtype, (
            'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype))
      expected_tensor = ops.convert_to_tensor_v2_with_dispatch(
          expected_gradient, dtype=dx.dtype, name='expected_gradient')
      # The value assertion is made to depend on the forward output so that it
      # only runs once the input is available. The dataset may raise
      # StopIteration to signal it is out of data, and in that case the
      # assertion should not run.
      with ops.control_dependencies([forward_out]):
        equality_check = check_ops.assert_equal(dx, expected_tensor)
      # Tie the returned gradient to the assertion so the check cannot be
      # skipped when the gradient is consumed.
      with ops.control_dependencies([equality_check]):
        return array_ops.identity(dx)

    return forward_out, grad

  # Keras sometimes has trouble serializing Lambda layers with a decorated
  # function, so a plain, non-decorated wrapper is what gets returned.
  def identity_with_grad_check(x):
    return _identity_with_grad_check(x)

  return identity_with_grad_check


def create_identity_with_nan_gradients_fn(have_nan_gradients):
  """Returns a function that optionally has NaN gradients.

  The returned callable is an identity function intended to be inserted into a
  model as a hook for injecting NaN gradients. Its gradient function checks the
  boolean tensor `have_nan_gradients`: if it is True the gradient is NaN,
  otherwise the gradient is passed through unchanged.

  Args:
    have_nan_gradients: A scalar boolean tensor. If True, gradients will be NaN.
      Otherwise, the gradient function is the identity function.

  Returns:
    An identity function whose gradient function will return NaNs, if
    `have_nan_gradients` is True.
  """
  @custom_gradient.custom_gradient
  def _identity_with_nan_gradients(x):
    """Function whose gradient is NaN iff `have_nan_gradients` is True."""
    forward_out = array_ops.identity(x)

    def grad(dx):
      """Returns `dx` multiplied by NaN or `dx` itself, chosen by a cond."""

      def _nan_gradient():
        # Multiplying by NaN yields a result with the same shape as `dx`.
        return dx * float('NaN')

      def _passthrough_gradient():
        return dx

      return control_flow_ops.cond(
          have_nan_gradients,
          true_fn=_nan_gradient,
          false_fn=_passthrough_gradient)

    return forward_out, grad

  # Keras sometimes has trouble serializing Lambda layers with a decorated
  # function, so a plain, non-decorated wrapper is what gets returned.
  def identity_with_nan_gradients(x):
    return _identity_with_nan_gradients(x)

  return identity_with_nan_gradients


class AssertTypeLayer(base_layer.Layer):
  """A layer which asserts its inputs are a certain type."""

  def __init__(self, assert_type=None, **kwargs):
    """Initializes the AssertTypeLayer.

    Args:
      assert_type: The dtype inputs are asserted to have, in any form accepted
        by `dtypes.as_dtype`. It is stored as the dtype's name. If falsy, no
        assertion is performed.
      **kwargs: Passed to the base Layer constructor.
    """
    if assert_type:
      self._assert_type = dtypes.as_dtype(assert_type).name
    else:
      self._assert_type = None
    super(AssertTypeLayer, self).__init__(**kwargs)

  def assert_input_types(self, inputs):
    """Asserts `inputs` are of the correct type. Should be called in call()."""
    if not self._assert_type:
      return
    # `inputs` may be an arbitrarily nested structure, so check every leaf.
    for tensor in nest.flatten(inputs):
      assert tensor.dtype.base_dtype == self._assert_type, (
          'Input tensor has type %s which does not match assert type %s' %
          (tensor.dtype.name, self._assert_type))


class MultiplyLayer(AssertTypeLayer):
  """A layer which multiplies its input by a scalar variable."""

  def __init__(self,
               regularizer=None,
               activity_regularizer=None,
               use_operator=False,
               var_name='v',
               **kwargs):
    """Initializes the MultiplyLayer.

    Args:
      regularizer: The weight regularizer on the scalar variable. May be given
        as a dict, in which case it is deserialized with
        `regularizers.deserialize`.
      activity_regularizer: The activity regularizer. May be given as a dict,
        in which case it is deserialized with `regularizers.deserialize`.
      use_operator: If True, multiply using the * operator. If False, multiply
        using math_ops.multiply.
      var_name: The name of the variable. It can be useful to pass a name other
        than 'v', to test having the attribute name (self.v) being different
        from the variable name.
      **kwargs: Passed to AssertTypeLayer constructor.
    """
    self._use_operator = use_operator
    self._var_name = var_name

    # A dict is treated as a serialized config (the form get_config() emits)
    # and is turned back into a regularizer object. This module's globals are
    # supplied as custom objects for the lookup.
    self._regularizer = regularizer
    if isinstance(regularizer, dict):
      self._regularizer = regularizers.deserialize(
          regularizer, custom_objects=globals())

    self._activity_regularizer = activity_regularizer
    if isinstance(activity_regularizer, dict):
      self._activity_regularizer = regularizers.deserialize(
          activity_regularizer, custom_objects=globals())

    super(MultiplyLayer, self).__init__(
        activity_regularizer=self._activity_regularizer, **kwargs)

  def build(self, _):
    """Creates the scalar variable `self.v`, initialized to one."""
    self.v = self.add_weight(
        name=self._var_name,
        shape=(),
        initializer='ones',
        regularizer=self._regularizer)
    self.built = True

  def call(self, inputs):
    """Asserts the input types, then returns `inputs` multiplied by `self.v`."""
    self.assert_input_types(inputs)
    return self._multiply(inputs, self.v)

  def _multiply(self, x, y):
    """Multiplies `x` by `y` with either the * operator or math_ops.multiply."""
    if not self._use_operator:
      return math_ops.multiply(x, y)
    return x * y

  def get_config(self):
    """Returns the base config extended with this layer's constructor args."""
    config = super(MultiplyLayer, self).get_config()
    config.update({
        'regularizer': regularizers.serialize(self._regularizer),
        'activity_regularizer': regularizers.serialize(
            self._activity_regularizer),
        'use_operator': self._use_operator,
        'var_name': self._var_name,
        'assert_type': self._assert_type,
    })
    return config


class MultiplyLayerWithoutAutoCast(MultiplyLayer):
  """Same as MultiplyLayer, but does not use AutoCastVariables."""

  def build(self, _):
    """Creates the scalar variable `self.v` with autocasting disabled.

    The variable is created in float32 when the layer's dtype is float16 or
    bfloat16; otherwise the layer's dtype is used as-is.
    """
    dtype = self.dtype
    if dtype in ('float16', 'bfloat16'):
      dtype = 'float32'
    # Honor the `var_name` constructor argument, as MultiplyLayer.build does,
    # so that the name reported by get_config() matches the variable that is
    # actually created. The default remains 'v'.
    self.v = self.add_weight(
        self._var_name, (),
        initializer='ones',
        dtype=dtype,
        experimental_autocast=False,
        regularizer=self._regularizer)
    self.built = True

  def call(self, inputs):
    """Asserts the input types, then multiplies `inputs` by the variable.

    Because the variable is not autocast, it is explicitly cast to the dtype of
    `inputs` before the multiplication.
    """
    self.assert_input_types(inputs)
    assert self.v.dtype in (dtypes.float32, dtypes.float64)
    return self._multiply(inputs, math_ops.cast(self.v, inputs.dtype))


class IdentityRegularizer(regularizers.Regularizer):
  """A regularizer which returns its input, asserting the input is float32."""

  def __call__(self, x):
    """Asserts `x` has dtype float32 and returns an identity op of `x`."""
    required_dtype = dtypes.float32
    assert x.dtype == required_dtype
    return array_ops.identity(x)

  def get_config(self):
    """Returns an empty config, as this regularizer takes no arguments."""
    return dict()


class ReduceSumRegularizer(regularizers.Regularizer):
  """A regularizer which returns the sum of all elements of its input."""

  def __call__(self, x):
    """Returns `math_ops.reduce_sum` applied to `x`."""
    total = math_ops.reduce_sum(x)
    return total

  def get_config(self):
    """Returns an empty config, as this regularizer takes no arguments."""
    return dict()
