# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to traverse the Dataset dependency structure."""
import collections

from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes


# Op types that `obtain_capture_by_value_ops` always selects, regardless of
# the dtype of their first output.
OP_TYPES_ALLOWLIST = ["DummyIterationCounter"]
# Output dtypes that `obtain_capture_by_value_ops` selects: any op whose first
# output has one of these dtypes is included. We allowlist all ops that produce
# variant tensors as output. This is a bit of overkill, but the other dataset
# `_inputs()` traversal strategies can't cover the case of function inputs
# that capture dataset variants.
TENSOR_TYPES_ALLOWLIST = [dtypes.variant]


def _traverse(dataset, op_filter_fn):
  """Traverse a dataset graph, returning nodes matching `op_filter_fn`."""
  result = []
  bfs_q = Queue.Queue()
  bfs_q.put(dataset._variant_tensor.op)  # pylint: disable=protected-access
  visited = []
  while not bfs_q.empty():
    op = bfs_q.get()
    visited.append(op)
    if op_filter_fn(op):
      result.append(op)
    for i in op.inputs:
      input_op = i.op
      if input_op not in visited:
        bfs_q.put(input_op)
  return result


def obtain_capture_by_value_ops(dataset):
  """Finds the allowlisted ops used to construct `dataset`.

  Allowlisted ops are stateful ops which are known to be safe to capture by
  value. An op is selected when the dtype of its first output appears in
  `TENSOR_TYPES_ALLOWLIST`, or when its op type appears in
  `OP_TYPES_ALLOWLIST`.

  Args:
    dataset: Dataset to find allowlisted stateful ops for.

  Returns:
    A list of the selected ops reachable from the dataset's variant tensor,
    in breadth-first order.
  """

  def _should_capture(op):
    # Check the output dtype first; the op-type check only runs when it fails.
    produces_allowlisted_tensor = (
        op.outputs[0].dtype in TENSOR_TYPES_ALLOWLIST)
    return produces_allowlisted_tensor or op.type in OP_TYPES_ALLOWLIST

  return _traverse(dataset, _should_capture)


def obtain_all_variant_tensor_ops(dataset):
  """Finds every op producing a variant tensor used to construct `dataset`.

  A dataset is typically built from a chain of transformations. Each
  transformation contributes zero or more dataset ops, and each of those ops
  produces a dataset variant tensor. This function collects all of them.

  Args:
    dataset: Dataset to find variant tensors for.

  Returns:
    A list of the reachable ops whose first output has dtype
    `dtypes.variant`, in breadth-first order.
  """

  def _produces_variant(op):
    return op.outputs[0].dtype == dtypes.variant

  return _traverse(dataset, _produces_variant)
