# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""

import functools

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend
from keras import callbacks as cbks
from keras.distribute import distribute_coordinator_utils as dc
from keras.distribute import distributed_training_utils_v1 as dist_utils
from keras.engine import partial_batch_padding_handler as padding_util
from keras.engine import training_arrays_v1
from keras.engine import training_utils_v1
from keras.utils.generic_utils import Progbar
from keras.utils.mode_keys import ModeKeys

# isort: off
from tensorflow.python.distribute import input_lib
from tensorflow.python.platform import tf_logging as logging


def _per_replica_execution_function(model, mode):
    exec_func = model._make_execution_function(mode)
    return (
        exec_func.inputs,
        exec_func.outputs,
        exec_func.updates_op,
        exec_func.session_kwargs,
    )


def _build_model(strategy, model, mode, inputs, targets=None):
    """Construct the distributed model for `mode` under `strategy`.

    Either clones the model onto each replica (when the model was compiled
    with distribution) or builds the distributed network in place.
    """
    if not model._compile_distribution:
        dist_utils._build_distributed_network(
            model, strategy, mode, inputs, targets
        )
    else:
        dist_utils.clone_model_on_replicas(
            model, strategy, mode, inputs=inputs, targets=targets
        )


def _make_train_step_fn(model, mode, strategy, output_labels):
    """Create step fn.

    Args:
      model: a Keras Model instance.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      strategy: a `tf.distribute.Strategy` instance.
      output_labels: the output labels for the step function.

    Returns:
      A step function to run by `tf.distribute.Strategy`.
    """

    def _step_fn(ctx, inputs):
        """A step fn that returns update ops."""
        # `inputs` may arrive as a (features, targets) pair; otherwise treat
        # the whole value as features with no explicit targets.
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        # When input feature is a dictionary of tensors, dictionary is
        # flattened to an array and passed as a model input. This results in
        # input mismatch when model input layer names are not sorted in
        # alphabetical order as `nest.flatten()` sorts dictionary elements by
        # keys. As so, transform input tensors into an array and order it along
        # `model._feed_input_names`.
        if isinstance(inputs, dict):
            inputs = [
                inputs[input_name] for input_name in model._feed_input_names
            ]

        # Build (or clone) the distributed model for this mode before asking
        # each replica for its execution function.
        _build_model(strategy, model, mode, inputs, targets)

        (
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        ) = strategy.extended.call_for_each_replica(
            _per_replica_execution_function,
            args=(dist_utils.get_distributed_model(model, mode), mode),
        )
        # Unwrap the per-replica values into flat lists that a single
        # `backend.function` can consume.
        (
            all_inputs,
            all_outputs,
            all_updates,
            all_session_args,
        ) = dist_utils.unwrap_values(
            strategy,
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        )
        combined_fn = backend.function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name="distributed_" + str(mode) + "_function",
            **all_session_args
        )

        # Register each labeled output with the reduction used to combine the
        # per-replica values; the host loop reads these after each run.
        for label, output in zip(output_labels, combined_fn.outputs):
            if label == "loss":
                reduce_op = tf.distribute.ReduceOp.SUM
            else:
                # We reduce all other metrics using mean for now. This is
                # temporary workaround until new metrics are in place.
                reduce_op = tf.distribute.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the
        # combined_fn: feed_dict, session kwargs, run options, run_metadata for
        # now. These should be handled appropriately
        return combined_fn.updates_op

    return _step_fn


def experimental_tpu_fit_loop(
    model,
    dataset,
    epochs=100,
    verbose=1,
    callbacks=None,
    initial_epoch=0,
    steps_per_epoch=None,
    val_dataset=None,
    validation_steps=None,
    validation_freq=1,
):
    """Fit loop for training with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset that returns inputs and targets
        epochs: Number of times to iterate over the data
        verbose: Integer, Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Ignored with the default value of `None`.
        val_dataset: Dataset for validation data.
        validation_steps: Number of steps to run validation for
            (only if doing validation from data tensors).
            Ignored with the default value of `None`.
        validation_freq: Only relevant if validation data is provided. Integer
            or `collections.abc.Container` instance (e.g. list, tuple, etc.). If
            an integer, specifies how many training epochs to run before a new
            validation run is performed, e.g. `validation_freq=2` runs
            validation every 2 epochs. If a Container, specifies the epochs on
            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
            validation at the end of the 1st, 2nd, and 10th epochs.

    Returns:
        The `model.history` object populated during training.

    Raises:
        ValueError: in case of invalid arguments.
    """
    mode = ModeKeys.TRAIN

    current_strategy = model._distribution_strategy
    # Run at most `steps_per_run` steps per device invocation, but never more
    # than a full epoch's worth of steps.
    iteration_value = min(
        steps_per_epoch, current_strategy.extended.steps_per_run
    )
    steps_per_run = backend.variable(
        value=iteration_value, dtype="int32", name="steps_per_run"
    )

    # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=1
    )
    scope.__enter__()

    out_labels = model.metrics_names or []

    step_fn = _make_train_step_fn(
        model, ModeKeys.TRAIN, current_strategy, out_labels
    )

    # Add initial dummy values for loss and other metric tensors.
    initial_loop_values = {}
    initial_loop_values["loss"] = tf.constant(1e7)
    for m in model._get_training_eval_metrics():
        tensor = m.result()
        initial_loop_values[m.name] = tf.zeros(tensor.shape, tensor.dtype)

    ctx = current_strategy.extended.experimental_run_steps_on_iterator(
        step_fn,
        iterator,
        iterations=steps_per_run,
        initial_loop_values=initial_loop_values,
    )
    train_op = ctx.run_op
    output_tensors = ctx.last_step_outputs

    do_validation = bool(validation_steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        verbose=verbose,
        count_mode="steps",
        mode=mode,
    )

    # Calculate the steps each time on the device.
    steps_to_run = [current_strategy.extended.steps_per_run] * (
        steps_per_epoch // current_strategy.extended.steps_per_run
    )
    if steps_per_epoch % current_strategy.extended.steps_per_run:
        steps_to_run.append(
            steps_per_epoch % current_strategy.extended.steps_per_run
        )
    target_steps = len(steps_to_run)

    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode
    )

    for epoch in range(initial_epoch, epochs):
        dist_utils._reset_metrics(model)
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        step_index = 0
        prev_step_count = None
        current_step = 0
        while current_step < target_steps:
            step_count = steps_to_run[current_step]
            batch_logs = {
                "batch": step_index,
                "size": 1,
                "num_steps": step_count,
            }
            callbacks._call_batch_hook(mode, "begin", step_index, batch_logs)
            # Only re-assign the steps-per-run variable when the step count
            # actually changes (it differs only for the last partial chunk).
            if prev_step_count is None or step_count != prev_step_count:
                backend.get_session().run(steps_per_run.assign(step_count))
                prev_step_count = step_count
            try:
                _, outputs = backend.batch_get_value([train_op, output_tensors])
            except tf.errors.OutOfRangeError:
                # NOTE: the multiplication must be inside the `%` argument.
                # The previous `... % steps_per_epoch * epochs` formatted the
                # string first and then *repeated the string* `epochs` times.
                logging.warning(
                    "Your dataset iterator ran out of data; "
                    "interrupting training. Make sure that your dataset "
                    "can generate at least `steps_per_epoch * epochs` "
                    "batches (in this case, %d batches)."
                    % (steps_per_epoch * epochs)
                )
                break

            batch_logs.update(outputs)
            callbacks._call_batch_hook(mode, "end", step_index, batch_logs)
            step_index = step_index + step_count
            current_step += 1

            if callbacks.model.stop_training:
                break

        if do_validation and training_utils_v1.should_run_validation(
            validation_freq, epoch
        ):
            logging.info("Running validation at fit epoch: %s", epoch)

            if model._compile_distribution:
                # Since we create a new clone from the original model we need to
                # copy the weights back to the original model before we can run
                # validation.
                dist_utils._copy_weights_to_original_model(
                    model, ModeKeys.TRAIN
                )

            val_outs = experimental_tpu_test_loop(
                model,
                val_dataset,
                steps=validation_steps,
                verbose=verbose,
                callbacks=callbacks,
            )
            if not isinstance(val_outs, list):
                val_outs = [val_outs]
            # Same labels assumed.
            for label, val_out in zip(out_labels, val_outs):
                epoch_logs["val_" + label] = val_out

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callbacks.model.stop_training:
            break
    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)

    if model._compile_distribution:
        # Copy the weights back from the replicated model to the original model.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
    scope.__exit__(None, None, None)
    return model.history


def experimental_tpu_test_loop(
    model, dataset, verbose=0, steps=None, callbacks=None
):
    """Test loop for evaluating with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset for input data.
        verbose: Integer, Verbosity mode 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring predictions finished.
            Ignored with the default value of `None`.
        callbacks: List of callbacks to be called during training

    Returns:
        Scalar loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the outputs.

    Raises:
        ValueError: if `steps` is `None` (it cannot be inferred here).
    """
    mode = ModeKeys.TEST
    current_strategy = model._distribution_strategy
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=0
    )
    scope.__enter__()

    out_labels = model.metrics_names

    def _test_step_fn(inputs):
        """A fn that returns output of single test step."""
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        (
            tf.distribute.get_replica_context().merge_call(
                _build_model, args=(model, mode, inputs, targets)
            )
        )

        (_, outputs, updates, _) = _per_replica_execution_function(
            dist_utils.get_distributed_model(model, mode), mode
        )
        with tf.control_dependencies([updates]):
            return [tf.identity(out) for out in outputs]

    test_input_data = iterator.get_next()
    per_replica_outputs = current_strategy.run(
        _test_step_fn, args=(test_input_data,)
    )
    output_tensors = {}
    for label, output in zip(out_labels, per_replica_outputs):
        if label == "loss":
            reduce_op = tf.distribute.ReduceOp.SUM
        else:
            # We reduce all other metrics using mean for now. This is temporary
            # workaround until new metrics are in place.
            reduce_op = tf.distribute.ReduceOp.MEAN
        output_tensors[label] = current_strategy.reduce(
            reduce_op, output, axis=None
        )
    test_op = tf.group(list(output_tensors.values()))

    if verbose >= 1:
        progbar = Progbar(target=steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    dist_utils._reset_metrics(model)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=False,
        epochs=1,
        steps_per_epoch=steps,
        verbose=verbose,
        count_mode="steps",
        mode=ModeKeys.TEST,
    )
    callbacks._call_begin_hook(mode)

    outs = [0.0] * len(model.metrics_names)
    if steps is not None:
        target_steps = steps
    else:
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps argument."
        )

    current_step = 0
    while current_step < target_steps:
        batch_logs = {"batch": current_step, "size": 1}
        callbacks._call_batch_hook(mode, "begin", current_step, batch_logs)
        try:
            _, batch_outs = backend.batch_get_value([test_op, output_tensors])
        except tf.errors.OutOfRangeError:
            warning_msg = (
                "Make sure that your dataset can generate at least "
                "`steps` batches (in this case, {} batches).".format(steps)
            )

            logging.warning(
                "Your dataset iterator ran out of data; "
                "interrupting evaluation. " + warning_msg
            )
            target_steps = current_step
            break
        for i, label in enumerate(model.metrics_names):
            if i == 0:
                # Loss is stateless metrics.
                outs[i] += batch_outs[label]
            else:
                # For all stateful metrics, the aggregation is handled by
                # mirrored vars.
                outs[i] = batch_outs[label]

        batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
        callbacks._call_batch_hook(mode, "end", current_step, batch_logs)
        if verbose == 1:
            progbar.update(current_step + 1)
        current_step += 1

    if verbose >= 1:
        # Progress bar finishes at the end.
        progbar.update(target_steps)
    callbacks._call_end_hook(mode)

    scope.__exit__(None, None, None)
    # The original `len(outs) >= 0` check was always true and would raise an
    # IndexError for an empty `outs`; also guard `target_steps == 0` (the
    # iterator was exhausted before the first step) against division by zero.
    if outs and target_steps:
        outs[0] /= target_steps

    if len(outs) == 1:
        return outs[0]
    return outs


def experimental_tpu_predict_loop(
    model, dataset, verbose=0, steps=None, callbacks=None
):
    """Predict loop for predicting with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset for input data.
        verbose: Integer, Verbosity mode 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring `_predict_loop` finished.
            Ignored with the default value of `None`.
        callbacks: List of callbacks to be called during training

    Returns:
        Array of predictions (if the model has a single output)
        or list of arrays of predictions
        (if the model has multiple outputs).

    Raises:
        ValueError: if `steps` is `None` (it cannot be inferred here).
    """
    mode = ModeKeys.PREDICT
    dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
    padding_handler = None
    if not dataset_fully_shaped:
        # TODO(hongjunchoi): Investigate whether operations from
        # PartialBatchPaddingHandler are unnecessarily pruned out
        # during graph optimization.
        padding_handler = padding_util.PartialBatchPaddingHandler(
            model._feed_output_shapes
        )
        batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(
            dataset
        )
        padding_handler.padded_batch_size = batch_size
        padding_handler.padding_mask = dataset.reduce(
            padding_handler.padding_mask, padding_handler.update_mask
        )

        dataset = dataset.map(padding_handler.pad_batch)
        dataset = dataset.unbatch()
        # Upon this point, it is guaranteed that the dataset does not
        # have partial batches. Thus, we set `drop_remainder=True` to
        # get static shape information about the elements in the dataset.
        dataset = dataset.batch(batch_size, drop_remainder=True)

        if prefetch_buffer is not None:
            dataset = dataset.prefetch(prefetch_buffer)

    current_strategy = model._distribution_strategy
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=0
    )
    scope.__enter__()

    def _predict_step_fn(inputs):
        """A fn that returns output of single prediction step."""

        (
            tf.distribute.get_replica_context().merge_call(
                _build_model, args=(model, mode, inputs)
            )
        )

        (_, outputs, updates, _) = _per_replica_execution_function(
            dist_utils.get_distributed_model(model, mode), mode
        )

        with tf.control_dependencies([updates]):
            return [tf.identity(out) for out in outputs]

    # TODO(hongjunchoi): When numpy array is passed as an input to `predict()`
    # use numpy arrays directly to avoid cumulating unnecessary input pipeline
    # ops.
    predict_input_data = iterator.get_next()
    per_replica_outputs = current_strategy.run(
        _predict_step_fn, args=(predict_input_data,)
    )
    output_tensors = dist_utils.flatten_per_replica_values(
        current_strategy, per_replica_outputs
    )
    # `tf.group` over the fixed output tensors is loop-invariant: build it
    # once here instead of on every step, where it would add duplicate ops to
    # the graph on each iteration.
    predict_ops = tf.group(output_tensors)

    if verbose >= 1:
        progbar = Progbar(target=steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    dist_utils._reset_metrics(model)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=False,
        epochs=1,
        steps_per_epoch=steps,
        verbose=verbose,
        count_mode="steps",
        mode=mode,
    )
    callbacks._call_begin_hook(mode)

    # Since we do not know how many samples we will see, we cannot pre-allocate
    # the returned Numpy arrays. Instead, we store one array per batch seen
    # and concatenate them upon returning.
    num_model_outputs = len(model.output_names)
    unconcatenated_outs = [[] for _ in range(num_model_outputs)]
    if steps is not None:
        target_steps = steps
    else:
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps argument."
        )

    current_step = 0
    while current_step < target_steps:
        batch_logs = {"batch": current_step, "size": 1}
        callbacks._call_batch_hook(mode, "begin", current_step, batch_logs)
        try:
            _, batch_outs = backend.batch_get_value(
                [predict_ops, output_tensors]
            )

        except tf.errors.OutOfRangeError:
            warning_msg = (
                "Make sure that your dataset can generate at least "
                "`steps` batches (in this case, {} batches).".format(steps)
            )

            logging.warning(
                "Your dataset iterator ran out of data; "
                "interrupting evaluation. " + warning_msg
            )
            break

        # TODO(priyag): maybe need to unwrap the outputs first for
        # MirroredStrategy.
        for i in range(num_model_outputs):
            output_start_index = i * current_strategy.num_replicas_in_sync
            output_end_index = (
                output_start_index + current_strategy.num_replicas_in_sync
            )
            single_model_output = batch_outs[
                output_start_index:output_end_index
            ]
            unconcatenated_outs[i].extend(single_model_output)

        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, "end", current_step, batch_logs)
        if verbose == 1:
            progbar.update(current_step + 1)
        current_step += 1

    if verbose >= 1:
        # Progress bar finishes at the end.
        progbar.update(current_step)

    callbacks._call_end_hook(mode)

    scope.__exit__(None, None, None)

    if len(unconcatenated_outs) == 1:
        prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
    else:
        prediction_result = [
            np.concatenate(out, axis=0) for out in unconcatenated_outs
        ]

    if padding_handler:
        # Strip the rows that were added to pad out partial batches.
        prediction_result = padding_handler.apply_mask(prediction_result)

    return prediction_result


class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for distribution strategy with single worker."""

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs
    ):
        """Fit loop for Distribution Strategies."""
        # Reject callbacks/optimizer combinations that cannot run under a
        # distribution strategy before doing any work.
        dist_utils.validate_callbacks(
            input_callbacks=callbacks, optimizer=model.optimizer
        )
        dist_utils.validate_inputs(x, y)

        # Resolve batch size and steps-per-epoch for this strategy, then
        # standardize the user data into a distributed dataset.
        batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
            model._distribution_strategy,
            x,
            batch_size,
            steps_per_epoch,
            ModeKeys.TRAIN,
            validation_split=validation_split,
        )
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x
        )
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle,
            epochs=epochs,
        )
        if not dist_utils.is_distributing_by_cloning(model):
            # Non-cloning path: standardize inside the strategy scope so
            # variables are created on the correct devices.
            with model._distribution_strategy.scope():
                (dataset, _, _) = model._standardize_user_data(
                    dataset,
                    sample_weight=sample_weight,
                    class_weight=class_weight,
                    batch_size=batch_size,
                    validation_split=validation_split,
                    shuffle=shuffle,
                )

        # Build the validation dataset, if validation data was supplied.
        val_dataset = None
        if validation_data:
            (
                val_x,
                val_y,
                val_sample_weights,
            ) = training_utils_v1.unpack_validation_data(validation_data)
            dist_utils.validate_inputs(val_x, val_y)
            _, validation_steps = dist_utils.process_batch_and_step_size(
                model._distribution_strategy,
                val_x,
                batch_size,
                validation_steps,
                ModeKeys.TEST,
            )

            val_dataset = model._distribution_standardize_user_data(
                val_x,
                val_y,
                sample_weight=val_sample_weights,
                class_weight=None,
                batch_size=batch_size,
                validation_split=validation_split,
                shuffle=shuffle,
                allow_partial_batch=True,
            )
        elif validation_split:
            raise ValueError(
                "validation_split argument is not supported with "
                "distribution strategies."
            )

        if backend.is_tpu_strategy(model._distribution_strategy):
            # TPU path requires a concrete step count up front.
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                dataset,
                steps_per_epoch,
                epochs,
                steps_name="steps_per_epoch",
            )
            if steps_per_epoch is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps_per_epoch argument."
                )

            if not tf.executing_eagerly():
                # Run TPU training in a custom loop in graph mode.
                return experimental_tpu_fit_loop(
                    model,
                    dataset,
                    epochs=epochs,
                    verbose=verbose,
                    callbacks=callbacks,
                    val_dataset=val_dataset,
                    initial_epoch=initial_epoch,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps,
                    validation_freq=validation_freq,
                )

        # Default (non-TPU, or eager TPU) path: reuse the array fit loop.
        return training_arrays_v1.fit_loop(
            model,
            dataset,
            batch_size=batch_size,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_inputs=val_dataset,
            shuffle=shuffle,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            steps_name="steps_per_epoch",
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs
    ):
        """Evaluate loop for Distribution Strategies."""
        dist_utils.validate_inputs(x, y)
        # Resolve batch size/steps and standardize the evaluation dataset.
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST
        )
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            allow_partial_batch=True,
        )

        if backend.is_tpu_strategy(model._distribution_strategy):
            # TPU path requires a concrete step count up front.
            steps = training_utils_v1.infer_steps_for_dataset(
                model, dataset, steps, steps_name="steps"
            )
            if steps is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps argument."
                )

            if not tf.executing_eagerly():
                # Run TPU evaluation in a custom loop in graph mode.
                return experimental_tpu_test_loop(
                    model,
                    dataset,
                    verbose=verbose,
                    steps=steps,
                    callbacks=callbacks,
                )

        # Default (non-TPU, or eager TPU) path: reuse the array test loop.
        return training_arrays_v1.test_loop(
            model,
            inputs=dataset,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs
    ):
        """Predict loop for Distribution Strategies."""
        dist_utils.validate_inputs(x=x, y=None)
        # Resolve batch size/steps and standardize the prediction dataset.
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT
        )
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        dataset = model._distribution_standardize_user_data(
            x, batch_size=batch_size, allow_partial_batch=True
        )
        if backend.is_tpu_strategy(model._distribution_strategy):
            # TPU path requires a concrete step count up front.
            steps = training_utils_v1.infer_steps_for_dataset(
                model, dataset, steps, steps_name="steps"
            )
            if steps is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps argument."
                )
            if not tf.executing_eagerly():
                # Run TPU prediction in a custom loop in graph mode.
                return experimental_tpu_predict_loop(
                    model,
                    dataset,
                    verbose=verbose,
                    steps=steps,
                    callbacks=callbacks,
                )
        # Default (non-TPU, or eager TPU) path: reuse the array predict loop.
        return training_arrays_v1.predict_loop(
            model,
            dataset,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )


def _train_with_multi_worker(method):
    """Decorator handles multi worker training with distribution strategy."""

    def wrapper(model, **kwargs):
        def _worker_fn(_):
            callbacks = kwargs.pop("callbacks", None)
            filtered_callbacks = dist_utils.filter_distributed_callbacks(
                callbacks, model
            )
            kwargs["callbacks"] = filtered_callbacks
            return method(model, **kwargs)

        return dc.run_distribute_coordinator(
            _worker_fn, model._distribution_strategy
        )

    return wrapper


class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for distribution strategy with multiple workers.

    Delegates the real work to a single-worker loop, running `fit` and
    `evaluate` under the distribute coordinator.
    """

    def __init__(self, single_worker_loop):
        # All actual training logic lives in the single-worker loop.
        self._single_worker_loop = single_worker_loop

    def fit(self, *args, **kwargs):
        multi_worker_fit = _train_with_multi_worker(
            self._single_worker_loop.fit
        )
        return multi_worker_fit(*args, **kwargs)

    def evaluate(self, *args, **kwargs):
        multi_worker_evaluate = _train_with_multi_worker(
            self._single_worker_loop.evaluate
        )
        return multi_worker_evaluate(*args, **kwargs)

    def predict(self, *args, **kwargs):
        # Currently predict is still using the single worker implementation.
        return self._single_worker_loop.predict(*args, **kwargs)
