# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated_endpoints(
    'train.experimental.MixedPrecisionLossScaleOptimizer')
@tf_export(v1=['mixed_precision.MixedPrecisionLossScaleOptimizer',
               'train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are
  computed, such as through `minimize()`.
  """

  def __init__(self, opt, loss_scale):
    """Wraps `opt` so that loss scaling is applied around it.

    The wrapper reuses the wrapped optimizer's locking setting and name.

    Args:
      opt: The `Optimizer` instance to wrap.
      loss_scale: The loss scale specification. It is resolved through
        `loss_scale_module.get()`, so it may be anything that function
        accepts, as long as the result is not None.

    Raises:
      ValueError: If `opt` is not an `Optimizer`, or if `loss_scale` resolves
        to None.
    """
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None')
    # Register both wrapped objects as trackable dependencies of this optimizer.
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.
    """
    loss = self._scale_loss(loss)
    # The wrapped optimizer differentiates the *scaled* loss, so the gradients
    # it returns are scaled by the same factor and must be unscaled below.
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    """Multiplies `loss` by the current loss scale.

    Args:
      loss: Either a loss value or a zero-argument callable returning one.

    Returns:
      If `loss` is callable, a new zero-argument callable that returns the
      scaled loss; otherwise the scaled loss itself. In both cases the loss
      scale is cast to the dtype of the loss before multiplying.
    """
    # The loss scale is read once here, even when `loss` is a callable that is
    # only evaluated later.
    loss_scale = self._loss_scale()
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    """Multiplies every gradient by the reciprocal of the loss scale.

    Args:
      grads: A list of gradients; entries may be `None`.

    Returns:
      A list of the same length in which `None` entries are kept as `None` and
      every other entry is replaced by its unscaled counterpart.
    """
    loss_scale = self._loss_scale()
    # Compute the reciprocal once and reuse it for every gradient.
    loss_scale_reciprocal = 1 / loss_scale
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
    """Multiplies one (possibly sparse) gradient by `loss_scale_reciprocal`.

    The reciprocal is cast to the gradient's dtype before multiplying, just as
    `_scale_loss` casts the loss scale to the dtype of the loss, so a gradient
    whose dtype differs from the loss scale's dtype can still be unscaled.

    Args:
      grad: A gradient, either an `IndexedSlices` or a dense value exposing
        `dtype`.
      loss_scale_reciprocal: The reciprocal of the loss scale.

    Returns:
      The scaled gradient. An `IndexedSlices` input yields a new
      `IndexedSlices` with scaled values and the original indices and dense
      shape.
    """
    if isinstance(grad, indexed_slices.IndexedSlices):
      # Only the values are scaled; indices and dense shape are carried over.
      scale = math_ops.cast(loss_scale_reciprocal, grad.values.dtype)
      grad_vals = grad.values * scale
      return indexed_slices.IndexedSlices(grad_vals, grad.indices,
                                          grad.dense_shape)
    return grad * math_ops.cast(loss_scale_reciprocal, grad.dtype)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. When the loss scale is a
    `DynamicLossScale`, it returns an `Operation` that conditionally applies
    gradients if all gradient values are finite. Otherwise no update is
    performed (nor is `global_step` incremented). When the loss scale is not
    dynamic, the call is forwarded directly to the wrapped optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation.  Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')

    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribution_strategy_context.get_replica_context()
    # Materialize the pairs so they can be iterated more than once downstream.
    grads_and_vars = tuple(grads_and_vars)

    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross replica context.

    When users are in a cross replica strategy, they must call this rather than
    `apply_gradients()`.

    Args:
      distribution: a `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()` and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`
    """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
    # `update` yields both the op that updates the loss scale and the
    # condition telling whether this step's gradients should be applied.
    loss_scale_update_op, should_apply_grads = (self._loss_scale.update(grads))

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    # Skip the update entirely (a no-op) when the gradients should not be
    # applied.
    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the variables of the Optimizer.

    The result is the wrapped optimizer's variables followed by the weights
    held by the loss scale.
    """
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access
