# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visitor restricting traversal to only the public tensorflow API."""

import re

import six

from tensorflow.python.util import tf_inspect


class PublicAPIVisitor(object):
  """Visitor to use with `traverse` to visit exactly the public TF API.

  Wraps another visitor and filters each node's children before and after
  delegating to it:

  * children that are private (see `private_map` and the underscore rules in
    `_is_private`) are removed before the wrapped visitor is called, so it
    never sees them;
  * children listed in `do_not_descend_map` are shown to the wrapped visitor
    but removed afterwards, so traversal does not recurse into them.
  """

  # Dunder names such as `__init__` start with an underscore but are not
  # considered private by that rule alone.
  _DUNDER_RE = re.compile('__.*__$')

  # Dunder attributes that are nevertheless never part of the public API.
  _HIDDEN_DUNDERS = frozenset(['__base__', '__class__', '__next_in_mro__'])

  def __init__(self, visitor):
    """Constructor.

    `visitor` should be a callable suitable as a visitor for `traverse`. It will
    be called only for members of the public TensorFlow API.

    Args:
      visitor: A visitor to call for the public API.
    """
    self._visitor = visitor
    self._root_name = 'tf'

    # Modules/classes we want to suppress entirely.
    self._private_map = {
        'tf': [
            'compiler',
            'core',
            # TODO(scottzhu): See b/227410870 for more details. Currently the
            # dtensor API is exposed under tf.experimental.dtensor, but in the
            # meantime we also have a tensorflow/dtensor directory which will
            # be treated as a python package. We want to avoid stepping into
            # the tensorflow/dtensor directory when visiting the API.
            # When tf.dtensor becomes the public API, it will be picked up
            # from tf.compat.v2.dtensor with priority, hiding the
            # tensorflow/dtensor package.
            'dtensor',
            'python',
        ],
        # Some implementations have this internal module that we shouldn't
        # expose.
        'tf.flags': ['cpp_flags'],
    }

    # Modules/classes we do not want to descend into if we hit them. Usually,
    # system modules exposed through platforms for compatibility reasons.
    # Each entry maps a module path to a list of child names to ignore in
    # traversal.
    self._do_not_descend_map = {
        'tf': [
            'examples',
            'flags',  # Don't add flags
            # TODO(drpng): This can be removed once sealed off.
            'platform',
            # TODO(drpng): This can be removed once sealed.
            'pywrap_tensorflow',
            # TODO(drpng): This can be removed once sealed.
            'user_ops',
            'tools',
            'tensorboard',
        ],

        ## Everything below here is legitimate.
        # It'll stay, but it's not officially part of the API.
        'tf.app': ['flags'],
        # Imported for compatibility between py2/3.
        'tf.test': ['mock'],
    }

  @property
  def private_map(self):
    """A map from parents to symbols that should not be included at all.

    This map can be edited, but it should not be edited once traversal has
    begun.

    Returns:
      The map marking symbols to not include.
    """
    return self._private_map

  @property
  def do_not_descend_map(self):
    """A map from parents to symbols that should not be descended into.

    This map can be edited, but it should not be edited once traversal has
    begun.

    Returns:
      The map marking symbols to not explore.
    """
    return self._do_not_descend_map

  def set_root_name(self, root_name):
    """Override the default root name of 'tf'."""
    self._root_name = root_name

  def _is_private(self, path, name, obj=None):
    """Return whether a child name is private and must be hidden.

    A child is private if it is listed under `path` in `private_map`, or if
    its name starts with an underscore and is not a dunder name, or if it is
    one of a few dunder attributes that are never part of the API.

    Args:
      path: Dotted path of the parent, including the root name.
      name: Name of the child.
      obj: The child object. Unused; accepted for interface compatibility.

    Returns:
      True if the child should be removed before the visitor sees it.
    """
    # TODO(wicke): Find out what names to exclude.
    del obj  # Unused.
    if path in self._private_map and name in self._private_map[path]:
      return True
    str_name = six.ensure_str(name)
    if str_name.startswith('_') and not self._DUNDER_RE.match(str_name):
      return True
    return name in self._HIDDEN_DUNDERS

  def _do_not_descend(self, path, name):
    """Return whether child `name` of `path` must not be descended into.

    Args:
      path: Dotted path of the parent, including the root name.
      name: Name of the child.

    Returns:
      True if `name` is listed under `path` in `do_not_descend_map`.
    """
    return name in self._do_not_descend_map.get(path, ())

  def __call__(self, path, parent, children):
    """Visitor interface, see `traverse` for details.

    Args:
      path: Dotted path of `parent` relative to the root; empty for the root.
      parent: The object currently being visited.
      children: List of `(name, child)` pairs. Mutated in place: private
        entries are removed before the wrapped visitor runs, and
        do-not-descend entries are removed after it runs.

    Raises:
      RuntimeError: If `parent` is a module whose path has more than 10
        dotted components.
    """
    # Avoid long waits in cases of pretty unambiguous failure.
    if tf_inspect.ismodule(parent) and len(
        six.ensure_str(path).split('.')) > 10:
      raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a '
                         'problem with an accidental public import.' %
                         (self._root_name, path))

    # Includes self._root_name
    full_path = '.'.join([self._root_name, path]) if path else self._root_name

    # Remove things that are not visible. Slice assignment keeps the same list
    # object (the caller holds a reference to it) and filters in one pass,
    # instead of repeated O(n) `list.remove` calls that compare entries by
    # `==`.
    children[:] = [(name, child)
                   for name, child in children
                   if not self._is_private(full_path, name, child)]

    self._visitor(path, parent, children)

    # Remove things that are visible, but which should not be descended into.
    children[:] = [(name, child)
                   for name, child in children
                   if not self._do_not_descend(full_path, name)]
