# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for control flow.

This file is copied from tensorflow/python/ops/control_flow_util.py.
"""

import tensorflow.compat.v2 as tf


def InXlaContext(graph):
    ctxt = graph._get_control_flow_context()
    return GetContainingXLAContext(ctxt) is not None


def GraphOrParentsInXlaContext(graph):
    while True:
        if InXlaContext(graph):
            return True
        try:
            graph = graph.outer_graph
        except AttributeError:
            return False


def IsInWhileLoop(op):
    ctxt = op._get_control_flow_context()
    return GetContainingWhileContext(ctxt) is not None


def GetContainingWhileContext(ctxt, stop_ctxt=None):
    """Returns the first ancestor WhileContext of `ctxt`.

    Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a
    while loop.

    Args:
      ctxt: ControlFlowContext
      stop_ctxt: ControlFlowContext, optional. If provided, the search will end
        if it sees stop_ctxt.

    Returns:
      `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext
      containing `ctxt`, or None if `ctxt` is not in a while loop.  If
      `stop_ctxt` is not `None`, this returns `ctxt` if it matches `stop_ctxt`
      in its traversal.
    """
    while ctxt:
        if ctxt.IsWhileContext() or ctxt == stop_ctxt:
            return ctxt
        ctxt = ctxt.outer_context
    return None


def GetContainingXLAContext(ctxt):
    """Returns the first ancestor XLAContext of `ctxt`.

    Returns `ctxt` if `ctxt` is a XLAContext, or None if `ctxt` is not in a
    while loop.

    Args:
      ctxt: ControlFlowContext

    Returns:
      `ctxt` if `ctxt` is a XLAContext, the most nested XLAContext containing
      `ctxt`, or None if `ctxt` is not in a while loop.
    """
    while ctxt:
        if ctxt.IsXLAContext():
            return ctxt
        ctxt = ctxt.outer_context
    return None


def smart_cond(pred, true_fn=None, false_fn=None, name=None):
    """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

    If `pred` is a bool or has a constant value, we return either `true_fn()`
    or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

    Args:
      pred: A scalar determining whether to return the result of `true_fn` or
        `false_fn`.
      true_fn: The callable to be performed if pred is true.
      false_fn: The callable to be performed if pred is false.
      name: Optional name prefix when using `tf.cond`.

    Returns:
      Tensors returned by the call to either `true_fn` or `false_fn`.

    Raises:
      TypeError: If `true_fn` or `false_fn` is not callable.
    """
    if isinstance(pred, tf.Variable):
        return tf.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
    return tf.__internal__.smart_cond.smart_cond(
        pred, true_fn=true_fn, false_fn=false_fn, name=name
    )


def constant_value(pred):
    """Return the bool value for `pred`, or None if `pred` had a dynamic value.

    Args:
      pred: A scalar, either a Python bool or a TensorFlow boolean variable
        or tensor, or the Python integer 1 or 0.

    Returns:
      True or False if `pred` has a constant boolean value, None otherwise.

    Raises:
      TypeError: If `pred` is not a Variable, Tensor or bool, or Python
        integer 1 or 0.
    """
    if isinstance(pred, tf.Tensor):
        return tf.get_static_value(pred)
    if pred in {0, 1}:  # Accept 1/0 as valid boolean values
        return bool(pred)
    if isinstance(pred, bool):
        return pred
    if isinstance(pred, tf.Variable):
        return None
    raise TypeError(
        "`pred` must be a Tensor, or a Python bool, or 1 or 0. "
        f"Received: {type(pred)}"
    )
