# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""

import tensorflow.compat.v2 as tf
from absl import flags

from keras import backend

# Global absl flag values; `get_strategy` reads the
# `keras_distribute_strategy_*` flags from this object.
FLAGS = flags.FLAGS


# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
    """Reports the strategy's `extended._global_batch_size` setting.

    NOTE(review): presumably True when the strategy interprets the input
    batch size as a global (cross-replica) batch size — confirm against the
    strategy implementations.

    Args:
      distribution_strategy: A `tf.distribute.Strategy`-like object exposing
        an `extended` attribute.

    Returns:
      The value of `distribution_strategy.extended._global_batch_size`.
    """
    strategy_extended = distribution_strategy.extended
    # Reads a private attribute; strategies lacking it raise AttributeError.
    return strategy_extended._global_batch_size


def call_replica_local_fn(fn, *args, **kwargs):
    """Invokes `fn`, which uses replica-local variables, in the right context.

    If a non-TPU distribution strategy is available and the caller is in a
    cross-replica context, `fn` is dispatched per replica through
    `strategy.extended.call_for_each_replica` inside `strategy.scope()`.
    In every other case `fn` is called directly.

    Args:
      fn: The function to call.
      *args: Positional arguments forwarded to `fn`.
      **kwargs: Keyword arguments forwarded to `fn`. A `strategy` entry, if
        present, is removed before forwarding and used as the distribution
        strategy (even when its value is `None`). Without it, the current
        strategy from `tf.distribute.get_strategy()` is used when
        `tf.distribute.has_strategy()` is true.

    Returns:
      The result of `fn(*args, **kwargs)`, or the result of
      `call_for_each_replica` when dispatched per replica.
    """
    # TODO(b/132666209): Remove this function when we support assign_*
    # for replica-local variables.
    if "strategy" in kwargs:
        strategy = kwargs.pop("strategy")
    elif tf.distribute.has_strategy():
        strategy = tf.distribute.get_strategy()
    else:
        strategy = None

    # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
    dispatch_per_replica = (
        not backend.is_tpu_strategy(strategy)
        and strategy
        and tf.distribute.in_cross_replica_context()
    )
    if not dispatch_per_replica:
        return fn(*args, **kwargs)

    with strategy.scope():
        return strategy.extended.call_for_each_replica(fn, args, kwargs)


def is_distributed_variable(v):
    """Checks whether `v` is a variable that is also a distributed value.

    Args:
      v: Any object.

    Returns:
      True only if `v` is an instance of both
      `tf.distribute.DistributedValues` and `tf.Variable`; False otherwise.
    """
    if not isinstance(v, tf.distribute.DistributedValues):
        return False
    return isinstance(v, tf.Variable)


def _create_tpu_strategy(tpu_addr):
    """Connects to and initializes the TPU system at `tpu_addr`, then wraps it in a strategy."""
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_addr
    )
    tf.config.experimental_connect_to_cluster(cluster_resolver)
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    return tf.distribute.experimental.TPUStrategy(cluster_resolver)


def get_strategy():
    """Creates a `tf.distribute.Strategy` object from flags.

    The strategy is selected by the `keras_distribute_strategy_class` flag:

    - `"tpu"`: `tf.distribute.experimental.TPUStrategy`, after connecting to
      and initializing the TPU given by the
      `keras_distribute_strategy_tpu_addr` flag.
    - `"multi_worker_mirrored"`:
      `tf.distribute.experimental.MultiWorkerMirroredStrategy()`.
    - `"mirrored"`: `tf.distribute.MirroredStrategy()`.
    - `"parameter_server"`: `tf.distribute.experimental.ParameterServerStrategy`
      using a `TFConfigClusterResolver`.
    - `"one_device"`: `tf.distribute.OneDeviceStrategy("/gpu:0")`.

    Example usage:

    ```python
    strategy = utils.get_strategy()
    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

    model.compile(...)
    train_ds, test_ds = ...
    model.fit(train_ds, validation_data=test_ds, epochs=10)
    ```

    Returns:
      `tf.distribute.Strategy` instance.

    Raises:
      ValueError: If the strategy class is `"tpu"` but
        `keras_distribute_strategy_tpu_addr` is empty, or if
        `keras_distribute_strategy_class` is not one of the accepted values.
    """
    # An ordered tuple rather than a set: a set of strings iterates in an
    # order that depends on per-process hash randomization, which made the
    # error message below list the options differently from run to run.
    accepted_strats = (
        "tpu",
        "multi_worker_mirrored",
        "mirrored",
        "parameter_server",
        "one_device",
    )
    cls = FLAGS.keras_distribute_strategy_class
    if cls == "tpu":
        tpu_addr = FLAGS.keras_distribute_strategy_tpu_addr
        if not tpu_addr:
            raise ValueError(
                "When using a TPU strategy, you must set the flag "
                "`keras_distribute_strategy_tpu_addr` (TPU address)."
            )
        strategy = _create_tpu_strategy(tpu_addr)
    elif cls == "multi_worker_mirrored":
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    elif cls == "mirrored":
        strategy = tf.distribute.MirroredStrategy()
    elif cls == "parameter_server":
        cluster_resolver = (
            tf.distribute.cluster_resolver.TFConfigClusterResolver()
        )
        strategy = tf.distribute.experimental.ParameterServerStrategy(
            cluster_resolver
        )
    elif cls == "one_device":
        strategy = tf.distribute.OneDeviceStrategy("/gpu:0")
    else:
        raise ValueError(
            "Unknown distribution strategy flag. Received: "
            f"keras_distribute_strategy_class={cls}. "
            f"It should be one of {accepted_strats}"
        )
    return strategy
