# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""Contains the `Node` class."""

import collections
import copy
import json

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine import base_layer_utils
from keras.saving.saved_model import json_utils
from keras.utils import tf_utils

# Marker that `Node.serialize` places in the first slot of a serialized
# first-argument entry when that element is a constant value rather than a
# Keras tensor.
_CONSTANT_VALUE = "_CONSTANT_VALUE"
# Tag that `Node.serialize` pairs with the JSON encoding of a constant
# `CompositeTensor`.
# Using dict to avoid conflict with constant string tensor.
_COMPOSITE_TYPE = {"_TYPE": "COMPOSITE"}


class Node:
    """A `Node` describes a layer `__call__()` event.

    A Functional model is a DAG with `Node` instances as nodes, and
    `KerasTensor` instances as edges. Nodes aren't `Layer` instances, because a
    single layer could be called multiple times, which would result in graph
    cycles.

    A `__call__()` event involves input tensors (and other input arguments),
    the layer that was called, and the resulting output tensors.
    A `Node` will include all this information.

    Since a single `Layer` could be called multiple times, the `Node` instances
    are stored on layers as a list. Each time a layer is called a node is added
    to `layer._inbound_nodes`. Each time the output of a layer is used by
    another layer, a node is added to `layer._outbound_nodes`.

    Every `KerasTensor` instance has a `KerasHistory` object attached,
    which tracks the `Node` that records the `__call__()` event that created
    the tensor. By recursively walking through `Node` instances
    via the `KerasHistory` metadata of `KerasTensor` instances, one can
    retrieve the entire DAG of a Functional model.

    Args:
        layer: The layer that was called in the `Layer.__call__()`
          event that this node represents.
        call_args: The positional arguments the layer was called with.
          Treated as an empty list when `None`.
        call_kwargs: The keyword arguments the layer was called with.
          Treated as an empty dict when `None`.
        outputs: The output tensors of the `Layer.__call__()` event.
          Treated as an empty list when `None`.

    Attributes:
        layer: The layer passed to the constructor.
        is_input: `True` when the node was created with neither positional
          nor keyword call arguments.
        call_args: Structure-level copy of the positional call arguments.
        call_kwargs: Structure-level copy of the keyword call arguments.
        outputs: Structure-level copy of the outputs.
        flat_input_ids: `str(id(t))` for every Keras tensor found in the
          flattened call arguments, in flattening order.
        flat_output_ids: `str(id(t))` for every tensor in the flattened
          outputs.
    """

    def __init__(self, layer, call_args=None, call_kwargs=None, outputs=None):
        """Records the call event and wires the node into the layer graph.

        Besides storing its arguments, construction has side effects on other
        objects: the node appends itself to `layer._inbound_nodes`, appends
        itself to `_outbound_nodes` of every layer that produced one of its
        Keras-tensor inputs, and overwrites `_keras_history` on every tensor
        in the flattened `outputs`.
        """
        call_args = [] if call_args is None else call_args
        call_kwargs = {} if call_kwargs is None else call_kwargs
        outputs = [] if outputs is None else outputs

        self.layer = layer
        # A node with no call arguments at all is treated as an input node.
        self.is_input = not call_args and not call_kwargs

        # These arguments are user-provided. Copy the structures here so that
        # future user modifications do not affect the node's metadata.
        # We copy using map_structure rather than python's shallow or deep copy,
        # because the args can be data structures (so shallow copy is
        # insufficient), but individual values might not support copy.copy
        # or be too expensive to deep copy.
        call_args = tf.nest.map_structure(lambda t: t, call_args)
        call_kwargs = tf.nest.map_structure(lambda t: t, call_kwargs)
        self.outputs = tf.nest.map_structure(lambda t: t, outputs)
        self.call_args = call_args
        self.call_kwargs = call_kwargs

        # Cached for performance.
        self._flat_arguments = tf.nest.flatten(
            (self.call_args, self.call_kwargs)
        )
        # Used to avoid expensive `nest` operations in the most common case.
        self._single_positional_tensor_passed = (
            not self.call_kwargs
            and len(self.call_args) == 1
            and tf.is_tensor(self.call_args[0])
        )

        if not tf.compat.v1.executing_eagerly_outside_functions():
            # Create TensorFlowOpLayers if needed (in TF1)
            for obj in self._flat_arguments:
                if isinstance(
                    obj, tf.Tensor
                ) and base_layer_utils.needs_keras_history(
                    obj, ignore_call_context=True
                ):
                    base_layer_utils.create_keras_history(obj)

        # Collect the Keras tensors among the flattened arguments, remembering
        # for each one its identity key and its position in the flat list so
        # that `map_arguments` can substitute values at the same positions.
        self._keras_inputs = []
        self._keras_inputs_ids_and_indices = []
        for i, ele in enumerate(self._flat_arguments):
            if is_keras_tensor(ele):
                self._keras_inputs.append(ele)
                kt_id = str(id(ele))
                kt_index = i
                self._keras_inputs_ids_and_indices.append((kt_id, kt_index))

        # Wire up Node to Layers.
        self.layer._inbound_nodes.append(self)
        for kt in self.keras_inputs:
            inbound_layer = kt._keras_history.layer
            if inbound_layer is not None:  # `None` for `Input` tensors.
                inbound_layer._outbound_nodes.append(self)

        # Set metadata on outputs.
        # This node was appended to `_inbound_nodes` just above, so its index
        # is the last position; this must stay after the append.
        node_index = len(self.layer._inbound_nodes) - 1
        for i, tensor in enumerate(tf.nest.flatten(outputs)):
            tensor._keras_history = KerasHistory(
                layer=layer, node_index=node_index, tensor_index=i
            )

        # Cached for performance.
        self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
        self.flat_output_ids = [
            str(id(t)) for t in tf.nest.flatten(self.outputs)
        ]

    @property
    def keras_inputs(self):
        """Tensors input to this node that can be traced back to a
        `keras.Input`.

        Concretely, these are the elements of the flattened call arguments
        that carry a `_keras_history` attribute, in flattening order.
        """
        return self._keras_inputs

    @property
    def parent_nodes(self):
        """Returns all the `Node`s whose output this node immediately depends
        on.

        Inputs whose history has `layer=None` contribute no parent. A parent
        appears once per input tensor it produced, so the list may contain
        duplicates.
        """
        node_deps = []
        for kt in self.keras_inputs:
            layer = kt._keras_history.layer
            node_index = kt._keras_history.node_index
            if layer is not None:  # `None` for `Input` tensors.
                node_deps.append(layer._inbound_nodes[node_index])
        return node_deps

    def iterate_inbound(self):
        """Yields tuples representing the data inbound from other nodes.

        One tuple is produced per entry of `keras_inputs`, read from that
        tensor's `_keras_history`.

        Yields:
          tuples like: (inbound_layer, node_index, tensor_index, tensor).
        """
        for kt in self.keras_inputs:
            keras_history = kt._keras_history
            layer = keras_history.layer
            node_index = keras_history.node_index
            tensor_index = keras_history.tensor_index
            yield layer, node_index, tensor_index, kt

    def map_arguments(self, tensor_dict):
        """Maps Keras Tensors to computed Tensors using `tensor_dict`.

        Args:
            tensor_dict: Mapping keyed by `str(id(keras_tensor))`. Each value
              must support `.pop()`; one item is popped (and therefore
              consumed) for every Keras-tensor argument of this node.
              Presumably each value is a list of computed tensors -- confirm
              against the caller.

        Returns:
            An `(args, kwargs)` pair shaped like the recorded
            `(call_args, call_kwargs)`, with every Keras tensor replaced by
            the value popped from `tensor_dict`.
        """
        if self._single_positional_tensor_passed:
            # Performance optimization for most common case.
            # NOTE(review): the flag is computed with `tf.is_tensor`, while
            # this list is filled via `is_keras_tensor`; this lookup assumes
            # the single positional tensor is also a Keras tensor.
            kt_id, _ = self._keras_inputs_ids_and_indices[0]
            return (tensor_dict[kt_id].pop(),), {}
        else:
            # Shallow-copy so the cached flat argument list is not mutated.
            flat_arguments = copy.copy(self._flat_arguments)
            for kt_id, kt_index in self._keras_inputs_ids_and_indices:
                flat_arguments[kt_index] = tensor_dict[kt_id].pop()

            args, kwargs = tf.nest.pack_sequence_as(
                (self.call_args, self.call_kwargs), flat_arguments
            )
            return args, kwargs

    def serialize(self, make_node_key, node_conversion_map):
        """Serializes `Node` for Functional API's `get_config`.

        Args:
            make_node_key: Callable invoked as
              `make_node_key(layer_name, node_index)` to build the key that
              is looked up in `node_conversion_map`.
            node_conversion_map: Mapping from node key to the node index to
              write out; keys that are missing fall back to `0`.

        Returns:
            The serialized first call argument: each of its elements becomes
            `[layer_name, node_index, tensor_index, kwargs]` for a Keras
            tensor, or `[_CONSTANT_VALUE, -1, value, kwargs]` otherwise,
            post-processed by `tf_utils.convert_inner_node_data`.

        Raises:
            TypeError: If the remaining call arguments cannot be dumped to
              JSON.
        """
        # Serialization still special-cases first argument.
        args, kwargs = self.call_args, self.call_kwargs
        inputs, args, kwargs = self.layer._call_spec.split_out_first_arg(
            args, kwargs
        )

        # Treat everything other than first argument as a kwarg.
        # Remaining positionals are paired with the call-spec argument names
        # that follow the first one; explicit kwargs then take precedence.
        arguments = dict(zip(self.layer._call_spec.arg_names[1:], args))
        arguments.update(kwargs)
        kwargs = arguments

        def _serialize_keras_tensor(t):
            """Serializes a single Tensor passed to `call`.

            Objects with `_keras_history` become
            `[layer_name, node_index, tensor_index]`; NumPy arrays and
            `tf.Tensor`s become nested lists; composite tensors become a
            `(_COMPOSITE_TYPE, json_string)` pair; anything else is returned
            unchanged.
            """
            if hasattr(t, "_keras_history"):
                kh = t._keras_history
                node_index = kh.node_index
                node_key = make_node_key(kh.layer.name, node_index)
                new_node_index = node_conversion_map.get(node_key, 0)
                return [kh.layer.name, new_node_index, kh.tensor_index]

            if isinstance(t, np.ndarray):
                return t.tolist()

            if isinstance(t, tf.Tensor):
                return backend.get_value(t).tolist()

            # Not using json_utils to serialize both constant Tensor and
            # constant CompositeTensor for saving format backward compatibility.
            if isinstance(t, tf.__internal__.CompositeTensor):
                return (_COMPOSITE_TYPE, json_utils.Encoder().encode(t))

            return t

        kwargs = tf.nest.map_structure(_serialize_keras_tensor, kwargs)
        # Dry-run the JSON encoding only to validate; the result is discarded.
        try:
            json.dumps(kwargs, default=json_utils.get_json_type)
        except TypeError:
            kwarg_types = tf.nest.map_structure(type, kwargs)
            raise TypeError(
                "Layer "
                + self.layer.name
                + " was passed non-JSON-serializable arguments. "
                + "Arguments had types: "
                + str(kwarg_types)
                + ". They cannot be serialized out "
                "when saving the model."
            )

        # `kwargs` is added to each Tensor in the first arg. This should be
        # changed in a future version of the serialization format.
        def serialize_first_arg_tensor(t):
            """Serializes one element of the first call argument."""
            if is_keras_tensor(t):
                kh = t._keras_history
                node_index = kh.node_index
                node_key = make_node_key(kh.layer.name, node_index)
                new_node_index = node_conversion_map.get(node_key, 0)
                data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
            else:
                # If an element in the first call argument did not originate as
                # a keras tensor and is a constant value, we save it using the
                # format ['_CONSTANT_VALUE', -1,
                # serialized_tensor_or_python_constant] (potentially including
                # serialized kwargs in an optional 4th argument).
                data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
            return tf_utils.ListWrapper(data)

        data = tf.nest.map_structure(serialize_first_arg_tensor, inputs)
        # A lone (non-nested) entry is wrapped in a list unless the layer asks
        # for its input structure to be preserved in the config.
        if (
            not tf.nest.is_nested(data)
            and not self.layer._preserve_input_structure_in_config
        ):
            data = [data]
        data = tf_utils.convert_inner_node_data(data)
        return data

    #############################################################
    # Properties for Backwards compatibility.
    # These only check the first input argument
    # As nodes are internal, they may be removed in the future.
    #############################################################

    @property
    def input_tensors(self):
        """The first positional call argument, or `[outputs]` for input nodes.

        NOTE(review): for a non-input node this indexes `call_args[0]`, so it
        assumes at least one positional argument was recorded.
        """
        if self.is_input:
            return [self.outputs]  # Used in `Layer.input`.
        return self.call_args[0]

    @property
    def output_tensors(self):
        """The node's outputs, wrapped in a list for input nodes."""
        if self.is_input:
            # NOTE(review): the comment previously here read "Used in
            # `Layer.input`", identical to `input_tensors`; presumably the
            # consumer of this branch is `Layer.output` -- confirm.
            return [self.outputs]
        return self.outputs

    @property
    def input_shapes(self):
        """Shapes of `input_tensors` as returned by `backend.int_shape`.

        For a non-input node whose mapped result has length one, the single
        shape is returned unwrapped.
        """
        input_shapes = tf.nest.map_structure(
            backend.int_shape, self.input_tensors
        )
        if len(input_shapes) == 1 and not self.is_input:
            return input_shapes[0]
        return input_shapes

    @property
    def output_shapes(self):
        """Shapes of `output_tensors`, with the same nesting structure."""
        return tf.nest.map_structure(backend.int_shape, self.output_tensors)

    @property
    def outbound_layer(self):
        """Alias of `layer`: the layer whose call this node records."""
        return self.layer

    @property
    def inbound_layers(self):
        """Return all layers that feed into the current node.

        Returns `[]` for input nodes, the layer itself when exactly one
        tensor argument carries Keras history, and a list of layers
        otherwise (which is empty when no argument qualifies).
        """
        if self.is_input:
            return []
        tensor_call_args = [
            x
            for x in self._flat_arguments
            if tf.is_tensor(x) and hasattr(x, "_keras_history")
        ]
        inbound_layers = tf.nest.map_structure(
            lambda t: t._keras_history.layer, tensor_call_args
        )
        if len(inbound_layers) == 1:
            return inbound_layers[0]
        return inbound_layers


class KerasHistory(
    collections.namedtuple("KerasHistory", "layer node_index tensor_index")
):
    """Immutable record of which layer call produced a given tensor.

    While a Keras graph network is being built, every tensor emitted by a
    layer call is stamped with one of these records (see `Node.__init__`,
    which assigns it to `tensor._keras_history`). Following the records from
    tensor to tensor is what allows the graph of layer calls to be
    reconstructed afterwards.

    Attributes:
      layer: The layer whose call produced the tensor.
      node_index: Which call of `layer` produced the tensor. A layer may be
        called several times (for example to share weights) and each call
        creates a new node; the node for this tensor is
        `layer._inbound_nodes[node_index]`.
      tensor_index: Position of the tensor within the flattened outputs of
        that call, as ordered by `nest.flatten`. It is zero whenever the
        call produced a single output.
    """

    # Declaring empty slots stops the subclass from growing a per-instance
    # `__dict__`, so instances stay as light as a plain `namedtuple`.
    __slots__ = ()


def is_keras_tensor(obj):
    """Returns whether `obj` carries Keras history metadata.

    Args:
        obj: Any object.

    Returns:
        `True` if `obj` exposes a `_keras_history` attribute (whatever its
        value, including `None`), `False` otherwise.
    """
    # A fresh sentinel distinguishes "attribute missing" from any real value.
    missing = object()
    return getattr(obj, "_keras_history", missing) is not missing
