"""Manages a Trackable object graph."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import weakref

from tensorflow.python.trackable import base
from tensorflow.python.trackable import converter
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.TrackableView", v1=[])
class TrackableView(object):
  """Gathers and serializes a trackable view.

  The view is anchored at a root object, which it holds only through a weak
  reference, and exposes the objects reachable from that root by following
  named child links.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> trackable_view = tf.train.TrackableView(root)

  Pass root to tf.train.TrackableView.children() to get the dictionary of all
  children directly linked to root by name.
  >>> trackable_view_children = trackable_view.children(root)
  >>> for item in trackable_view_children.items():
  ...   print(item)
  ('a_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
  ('b_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
  ('vars', ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
  numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
  ('leaf', ...)

  """

  def __init__(self, root):
    """Configure the trackable view.

    Args:
      root: The `Trackable` object at which traversals of this view start. May
        be passed directly or as a `weakref.ref` to it; in both cases only a
        weak reference is stored.
    """
    # TrackableView should never contain a strong reference to root, since it
    # may result in a cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> TrackableView -> root
    if isinstance(root, weakref.ref):
      self._root_ref = root
    else:
      self._root_ref = weakref.ref(root)

  @classmethod
  def children(cls, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      Dictionary mapping each child's name to the child, after the child has
      been passed through `converter.convert_to_trackable`.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    return {
        name: converter.convert_to_trackable(ref, parent=obj)
        for name, ref in obj._trackable_children(save_type, **kwargs).items()
    }

  @property
  def root(self):
    """The root object of this view.

    Raises:
      AssertionError: If the root is held through a weak reference and the
        referent is no longer alive.
    """
    root_ref = self._root_ref
    if not isinstance(root_ref, weakref.ref):
      return root_ref
    derefed = root_ref()
    if derefed is None:
      # Raised explicitly instead of through an `assert` statement so that the
      # check is not stripped when Python runs with -O; without it a dead
      # reference would silently yield None here and fail later, far from the
      # actual cause.
      raise AssertionError(
          "The root object of this TrackableView is no longer alive; the view "
          "only holds a weak reference to it.")
    return derefed

  def descendants(self):
    """Returns all nodes reachable from `self.root` in breadth-first order."""
    return self._descendants_with_paths()[0]

  def _descendants_with_paths(self):
    """Traverses the object graph breadth-first starting at `self.root`.

    Returns:
      A tuple `(bfs_sorted, node_paths)`. `bfs_sorted` is a list of every
      reachable node in the order it was visited, starting with the root.
      `node_paths` is an `ObjectIdentityDictionary` mapping each of those
      nodes to a tuple of `TrackableReference`s describing the first path
      found from the root to that node (the empty tuple for the root itself).
    """
    # Dereference the weak reference once and reuse the result for both the
    # queue and the path table.
    root = self.root
    bfs_sorted = []
    to_visit = collections.deque([root])
    node_paths = object_identity.ObjectIdentityDictionary()
    node_paths[root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      bfs_sorted.append(current_trackable)
      path_to_current = node_paths[current_trackable]
      for name, dependency in self.children(current_trackable).items():
        # A node already present in node_paths has been reached before; keep
        # its first recorded path and do not enqueue it again.
        if dependency in node_paths:
          continue
        node_paths[dependency] = (
            path_to_current + (base.TrackableReference(name, dependency),))
        to_visit.append(dependency)
    return bfs_sorted, node_paths
