# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for control flow.

This file is necessary to avoid cyclic dependencies between ops.py and
control_flow_ops.py.
"""

import os
import traceback

from tensorflow.python import tf2
from tensorflow.python.platform import tf_logging as logging

# Control flow v2 is on when TF2 behavior is enabled (unless explicitly opted
# out with TF_ENABLE_CONTROL_FLOW_V2=0), or when any of the opt-in environment
# variables below is set to a value other than "0".
ENABLE_CONTROL_FLOW_V2 = (
    (tf2.enabled() and os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
    any(os.getenv(flag, "0") != "0"
        for flag in ("TF_ENABLE_CONTROL_FLOW_V2",
                     "TF_ENABLE_COND_V2",
                     "TF_ENABLE_WHILE_V2",
                     "TF_ENABLE_TENSOR_ARRAY_V2")))


# TODO(b/137793122): Remove this.
def enable_control_flow_v2():  # pylint: disable=invalid-name
  """Use control flow v2.

  Sets the module-level `ENABLE_CONTROL_FLOW_V2` flag to True. Takes no
  arguments and returns nothing.

  Do not use this symbol. This will be removed.
  """
  # Rebind the module-level flag instead of creating a function-local name.
  global ENABLE_CONTROL_FLOW_V2
  ENABLE_CONTROL_FLOW_V2 = True


def EnableControlFlowV2(graph):
  """Decides whether control flow v2 should be used in `graph`.

  Args:
    graph: The graph being built. It is only inspected when the module-level
      `ENABLE_CONTROL_FLOW_V2` flag is not set.

  Returns:
    A truthy value if v2 control flow should be used, i.e. when the global flag
    is on, or when `graph` is building a function and has no `_captured`
    attribute; a falsy value otherwise.
  """
  globally_enabled = ENABLE_CONTROL_FLOW_V2
  if globally_enabled:
    return globally_enabled
  # Enable new control flow in FuncGraphs (but not legacy _FuncGraphs); the
  # two are told apart by the presence of a `_captured` attribute.
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return graph.building_function and not hasattr(graph, "_captured")


def IsInXLAContext(op):
  """Returns whether `op` is marked for XLA compilation or is in an XLAContext.

  Args:
    op: Operation to inspect.

  Returns:
    True if `op` has a truthy `_XlaCompile` attr, or if its control flow
    context is (or is nested in) an XLAContext; False otherwise.
  """
  try:
    if op.get_attr("_XlaCompile"):
      return True
  except ValueError:
    # The attr could not be read; fall back to the enclosing contexts.
    pass
  enclosing_xla = GetContainingXLAContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_xla is not None


def InXlaContext(graph):
  """Returns whether `graph`'s control flow context is within an XLAContext."""
  enclosing_xla = GetContainingXLAContext(
      graph._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_xla is not None


def GraphOrParentsInXlaContext(graph):
  """Returns whether `graph` or any of its outer graphs is in an XLAContext.

  Follows the `outer_graph` attribute from `graph` outwards. The walk stops
  with False at the first graph that has no `outer_graph` attribute.
  """
  current = graph
  while not InXlaContext(current):
    try:
      current = current.outer_graph
    except AttributeError:
      # No further enclosing graph to look at.
      return False
  return True


def IsInWhileLoop(op):
  """Returns whether `op`'s control flow context is within a WhileContext."""
  enclosing_loop = GetContainingWhileContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_loop is not None


def IsInCond(op):
  """Returns whether `op`'s control flow context is within a CondContext."""
  enclosing_cond = GetContainingCondContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_cond is not None


def IsSwitch(op):
  """Return true if `op` is a Switch (plain or Ref variant)."""
  return op.type in ("Switch", "RefSwitch")


def IsMerge(op):
  """Return true if `op` is a Merge (plain or Ref variant)."""
  return op.type in ("Merge", "RefMerge")


def IsLoopEnter(op):
  """Return true if `op` is an Enter (plain or Ref variant)."""
  return op.type in ("Enter", "RefEnter")


def IsLoopExit(op):
  """Return true if `op` is an Exit (plain or Ref variant)."""
  return op.type in ("Exit", "RefExit")


def IsCondSwitch(op):
  """Return true if `op` is the Switch for a conditional.

  Args:
    op: Operation to inspect.

  Returns:
    True iff `op` is a Switch/RefSwitch with at least one output and every
    consumer of its outputs is in a cond context. A Switch whose outputs have
    no consumers is reported as a cond switch.
  """
  if not IsSwitch(op):
    return False
  if not op.outputs:
    return False
  # Switch nodes are not part of the cond control flow context that they
  # represent, so consider the consumers of its outputs to determine if it is
  # cond switch or not. A switch is a cond switch iff all its consumers are in
  # cond contexts.
  for output in op.outputs:
    for consumer in output.consumers():
      ctxt = consumer._get_control_flow_context()  # pylint: disable=protected-access
      # Enter consumers are judged by the context that surrounds them. An op
      # may have no control flow context at all (e.g. when it was imported via
      # import_graph_def), so only step outwards when there is a context.
      if ctxt is not None and IsLoopEnter(consumer):
        ctxt = ctxt.outer_context
      if ctxt is None or not ctxt.IsCondContext():
        # One consumer outside a cond context settles the answer.
        return False
  return True


def IsCondMerge(op):
  """Return true if `op` is the Merge for a conditional.

  Args:
    op: Operation to inspect.

  Returns:
    True iff `op` is a Merge/RefMerge with at least one input and the output
    context of every input's op is a cond context.
  """
  if not IsMerge(op) or not op.inputs:
    return False
  # Merge nodes are not part of the cond control flow context that they
  # represent, so look at the contexts of the merge's inputs instead: a merge
  # is a cond merge iff all of its inputs are in cond contexts.
  input_ctxts = [GetOutputContext(tensor.op) for tensor in op.inputs]
  return all(
      ctxt is not None and ctxt.IsCondContext() for ctxt in input_ctxts)


def IsLoopSwitch(op):
  """Return true if `op` is the Switch for a while loop.

  That is, `op` is a Switch whose own control flow context is a WhileContext
  and which is not the Switch of a conditional.
  """
  if not IsSwitch(op):
    return False
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if ctxt is None:
    return False
  return ctxt.IsWhileContext() and not IsCondSwitch(op)


def IsLoopMerge(op):
  """Return true if `op` is the Merge for a while loop.

  That is, `op` is a Merge whose own control flow context is a WhileContext
  and which is not the Merge of a conditional.
  """
  if not IsMerge(op):
    return False
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if ctxt is None:
    return False
  return ctxt.IsWhileContext() and not IsCondMerge(op)


def IsLoopConstantEnter(op):
  """Return true iff op is a loop invariant.

  An op qualifies when it is an Enter op whose `is_constant` attr is set. For
  Enter ops the value of that attr is returned as-is.
  """
  if not IsLoopEnter(op):
    return False
  return op.get_attr("is_constant")


def GetLoopConstantEnter(value):
  """Return the enter op if we can infer `value` to be a loop invariant.

  Args:
    value: A tensor; the search starts at the op that produces it.

  Returns:
    The loop-constant Enter op reached by following the first input through
    any chain of Switch/Identity ops (and their Ref variants), or None if that
    chain does not end at a loop-constant Enter.
  """
  pass_through_types = {"Switch", "RefSwitch", "Identity", "RefIdentity"}
  candidate = value.op
  # Skip over ops that merely forward their first input.
  while candidate.type in pass_through_types:
    candidate = candidate.inputs[0].op
  if IsLoopConstantEnter(candidate):
    return candidate
  return None


def GetOutputContext(op):
  """Return the control flow context for the output of an op.

  For most ops this is the op's own context. The output of an Exit op belongs
  to the context outside the one the Exit op is in.
  """
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes usually have a control flow context, except in the case where the
  # exit node was imported via import_graph_def (in which case no nodes have
  # control flow contexts).
  if own_ctxt is None:
    return own_ctxt
  return own_ctxt.outer_context if IsLoopExit(op) else own_ctxt


def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Finds the innermost WhileContext that is or encloses `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links
  outwards.

  Args:
    ctxt: ControlFlowContext, or None.
    stop_ctxt: ControlFlowContext, optional. If provided, the search also ends
      (returning that context) as soon as it reaches a context equal to
      `stop_ctxt`.

  Returns:
    The first context on the way outwards that is a WhileContext or matches
    `stop_ctxt`, or None if the chain of contexts ends before either is found.
  """
  current = ctxt
  while current:
    reached_target = current.IsWhileContext() or current == stop_ctxt
    if reached_target:
      return current
    current = current.outer_context
  return None


def GetContainingXLAContext(ctxt):
  """Finds the innermost XLAContext that is or encloses `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links
  outwards.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if `ctxt` is an XLAContext, otherwise the most nested XLAContext
    containing `ctxt`, or None if `ctxt` is not in an XLA context.
  """
  current = ctxt
  while current:
    if current.IsXLAContext():
      return current
    current = current.outer_context
  return None


def GetContainingCondContext(ctxt):
  """Finds the innermost CondContext that is or encloses `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links
  outwards.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if `ctxt` is a CondContext, otherwise the most nested CondContext
    containing `ctxt`, or None if `ctxt` is not in a cond.
  """
  current = ctxt
  while current:
    if current.IsCondContext():
      return current
    current = current.outer_context
  return None


def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.

  Walks outwards from `ctxt` along `outer_context` and compares by identity.
  Because the walk ends at None, passing None as `maybe_containing_ctxt`
  always yields True.
  """
  current = ctxt
  while True:
    if current is maybe_containing_ctxt:
      return True
    if current is None:
      return False
    current = current.outer_context


def OpInContext(op, ctxt):
  """Returns true if `op`'s control flow context is `ctxt` or nested in it."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return IsContainingContext(op_ctxt, ctxt)


def TensorInContext(tensor, ctxt):
  """Returns true if the op producing `tensor` is in `ctxt` (see OpInContext)."""
  producer = tensor.op
  return OpInContext(producer, ctxt)


def CheckInputFromValidContext(op, input_op):
  """Checks that `input_op` can be used from `op`s context.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well.

  This function returns nothing. When the input is invalid, the error message,
  both while contexts and the tracebacks of both ops are written to the info
  log before the exception is raised.

  Args:
    op: Operation
    input_op: Operation

  Raises:
    ValueError: if input_op is from an invalid context.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # For an Exit op this is the context outside the loop it exits.
  input_ctxt = GetOutputContext(input_op)
  valid = False

  if not input_ctxt:
    # input_op isn't in a control flow context.
    valid = True
  elif op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    valid = True
  else:
    while_ctxt = GetContainingWhileContext(op_ctxt)
    input_while_ctxt = GetContainingWhileContext(input_ctxt)

    if while_ctxt is None:
      if input_while_ctxt is None:
        # Neither op nor input_op is in a while loop, but one or both are in
        # conds. We allow this, although execution will fail if the branch
        # corresponding to input_op's cond context isn't taken.
        valid = True
      # Invalid if op isn't in a while loop and input_op is. Unless...
      if IsLoopEnter(op):
        # WhileContext._BuildLoop clears context for Enter nodes.
        valid = True
      if IsSwitch(op):
        # CondContext.AddValue clears context for Switch nodes.
        valid = True
    elif IsContainingContext(while_ctxt, input_while_ctxt):
      # input_op is in a while loop which contains op's while loop (or not in a
      # while loop at all).
      valid = True
    elif (while_ctxt.grad_state and
          IsContainingContext(while_ctxt.grad_state.forward_context,
                              input_while_ctxt)):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build while
      # loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      valid = True
    elif (while_ctxt.grad_state and
          while_ctxt.grad_state.forward_context is
          input_while_ctxt._outer_context):  # pylint: disable=protected-access
      # op is in a gradient context and input_op is in a child of the associated
      # forward pass context. This case is needed for the gradients of while
      # loops with conds.
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context is while_ctxt):
      # input_op is in the gradient context of op's context. This case is needed
      # when the gradient of a while loop gradient is requested (this will
      # eventually fail unless there is a stop_gradient() or similar).
      valid = True
    elif (input_while_ctxt.grad_state and
          input_ctxt.grad_state.forward_context.grad_state and
          input_ctxt.grad_state.forward_context.grad_state.forward_context is
          while_ctxt):
      # input_op is in the grad grad context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      valid = True

  # `valid` can only still be False after the `else` branch above ran, so
  # `while_ctxt` and `input_while_ctxt` are always bound in the block below.
  if not valid:
    if while_ctxt:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because they "
          "are in different while loops.")
    else:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because "
          f"'{input_op.name}' is in a while loop.")

    # Log the error message plus the relevant stack traces. The stacks may be
    # useful for debugging this error, but we don't want to raise an
    # unreadable exception.
    log_msg = error_msg
    log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
    log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
    log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
        op.name, "".join(traceback.format_list(op.traceback)),
        input_op.name, "".join(traceback.format_list(input_op.traceback)))
    logging.info(log_msg)
    raise ValueError(error_msg + " See info log for more details.")


def GetWhileContext(op):
  """Get the WhileContext to which this op belongs.

  Returns whatever `GetWhileContext()` of the op's control flow context
  returns. If the op has no (truthy) control flow context, that falsy context
  value is returned unchanged.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return op_ctxt.GetWhileContext() if op_ctxt else op_ctxt
