# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Library for constructing a training loop, suitable for TPUs."""

from typing import Any, Callable, Iterable, List, Optional, Union

from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.types import core as core_types


def while_loop(condition: Callable[..., Any],
               body: Callable[..., Any],
               inputs: Optional[List[Any]] = None,
               infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
               name: Any = None) -> Any:
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed from
  infeed_queue if infeed_queue is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors, optionally followed by Operations that must run on
  every iteration.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue whose dequeued tuple is
      appended to the arguments of `body`. Infeed is never passed to
      `condition`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature, or if the dtypes
      of the tensors returned by `body` differ from the dtypes of `inputs`.
    ValueError: if `body` returns an Operation before a Tensor, or if
      `infeed_queue` is given but the TPU context has no number of shards.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    # Tensor names are only looked up on the error path.
    input_names = [i.name for i in inputs]
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          f"inputs. You specified {input_arity} inputs: {input_names}, but "
          f"the loop body needs {body_arg_error}")
    raise TypeError(
        "Supplied loop body function cannot be called with the specified "
        f"inputs. You specified {input_arity} inputs: {input_names} and "
        f"{infeed_queue.number_of_tuple_elements} additional inputs from "
        f"infeed, but the computation needs {body_arg_error}")

  # The condition never receives infeed, so it is checked against the
  # loop-carried inputs only.
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    input_names = [i.name for i in inputs]
    message = (
        "Supplied loop condition function cannot be called with the "
        f"specified inputs. You specified {input_arity} inputs: "
        f"{input_names}, but the loop condition needs {condition_arg_error}")
    if infeed_queue is not None:
      message += ". Note that infeed is not passed to the loop condition."
    raise TypeError(message)

  def condition_wrapper(*loop_vars):
    """Calls `condition`, hiding the dummy value used by arity-0 loops."""
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      loop_vars = []
    return condition(*loop_vars)

  def body_wrapper(*loop_vars):
    """Wrapper around `body` that handles infeed queues and control deps."""
    # Discards the dummy output added for arity-0 loops.
    loop_vars = [] if input_arity == 0 else list(loop_vars)

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue is not None:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = list(infeed_queue.generate_dequeue_op())
    else:
      dequeue_ops = []
    outputs = body(*(loop_vars + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Tensors must all come before Operations. The check is done on plain
    # Python booleans (False sorts before True) so that it never falls back to
    # `==` between a Tensor and an Operation.
    is_operation = [isinstance(o, ops.Operation) for o in outputs]
    if is_operation != sorted(is_operation):
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    # Separates the returned Operations and Tensors.
    output_operations = [
        o for o, is_op in zip(outputs, is_operation) if is_op
    ]
    output_tensors = [
        o for o, is_op in zip(outputs, is_operation) if not is_op
    ]

    output_types = [t.dtype for t in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed. It is kept in a list because
    # `control_flow_ops.tuple` below expects a sequence of tensors, not a
    # single scalar tensor.
    if not output_tensors:
      output_tensors = [array_ops.constant(0)]

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)


def repeat(
    n: int,
    body: Callable[..., Union[core_types.TensorLike, Iterable]],  # pylint:disable=g-bare-generic
    inputs: Optional[List[core_types.TensorLike]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    name: Any = None) -> List[core_types.TensorLike]:
  """Builds a training loop that runs `body` a fixed number of times.

  The loop is built with `while_loop`, using an iteration counter that starts
  at 0 and is placed in front of the loop-carried values. The counter is
  incremented on every step and is not passed to `body`.

  Args:
    n: the number of loop iterations.
    body: a Python function that builds the loop body. It takes the current
      loop-carried values and returns their updated values, either as a single
      value or as a list/tuple.
    inputs: a list of initial values passed into the training loop, or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue forwarded to `while_loop`,
      which appends the dequeued tuple to the arguments of `body`.
    name: (Deprecated) Forwarded to `while_loop`, which ignores it.

  Returns:
    The final values of the loop-carried tensors, without the iteration
    counter. If the loop yields only the counter, the op that produces it is
    returned instead of an empty list.

  Raises:
    TypeError: propagated from `while_loop`.
    ValueError: propagated from `while_loop`.
  """

  def _as_list(value):
    # Lists and tuples are copied; any other value becomes a singleton list.
    return list(value) if isinstance(value, (list, tuple)) else [value]

  def _keep_going(step, *unused_loop_vars):
    # Only the leading counter decides whether the loop continues.
    return step < n

  def _advance(step, *loop_vars):
    # Bump the counter, then run the user's body on the remaining values.
    return [step + 1, *_as_list(body(*loop_vars))]

  # The counter always occupies slot 0 of the loop variables.
  initial_values = [0]
  if inputs is not None:
    initial_values += _as_list(inputs)

  results = _as_list(
      while_loop(
          _keep_going,
          _advance,
          inputs=initial_values,
          infeed_queue=infeed_queue,
          name=name))

  if len(results) == 1:
    # Returns the Op rather than an empty list.
    return results[0].op
  return results[1:]
