# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for keras functional model."""

import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine import input_layer as input_layer_module
from keras.engine import keras_tensor
from keras.engine import node as node_module

_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG = (
    "Found unexpected instance while processing input tensors for Keras "
    "functional model. Expecting a KerasTensor, which is the output of "
    "`tf.keras.Input()` or of a Keras layer `call()`. Got: {}"
)


def is_input_keras_tensor(tensor):
    """Check if tensor is directly generated from `tf.keras.Input`.

    This check is useful when constructing a functional model, since we need
    to clone Nodes and KerasTensors if the model is built from tensors that
    are not direct outputs of `tf.keras.Input`.

    Args:
      tensor: A `KerasTensor` used as an input to the functional model.

    Returns:
      bool. Whether the tensor is directly generated from `tf.keras.Input`.

    Raises:
      ValueError: if the tensor is not a KerasTensor instance.
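
    Example (a minimal sketch assuming a single `Dense` layer):

    ```python
    inputs = tf.keras.Input(shape=(10,))
    outputs = tf.keras.layers.Dense(1)(inputs)

    is_input_keras_tensor(inputs)   # True: comes from `tf.keras.Input`.
    is_input_keras_tensor(outputs)  # False: output of a layer call.
    ```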
    """
    if not node_module.is_keras_tensor(tensor):
        raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(tensor))
    return tensor.node.is_input


def find_nodes_by_inputs_and_outputs(inputs, outputs):
    """Fetch all Nodes in the graph defined by "inputs" and "outputs".

    This method is used to find and then clone Nodes when creating a new
    sub-model from an existing functional model.

    Args:
      inputs: A nested structure of KerasTensors to use as model inputs.
      outputs: A nested structure of KerasTensors to use as model outputs.

    Returns:
      A list of Nodes that are connected to the inputs and outputs.

    Raises:
      ValueError: when inputs and outputs are disconnected or in case of
        unexpected objects in the inputs/outputs.
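
    Example (a minimal sketch assuming a two-layer `Dense` graph):

    ```python
    inputs = tf.keras.Input(shape=(10,))
    x = tf.keras.layers.Dense(5)(inputs)
    outputs = tf.keras.layers.Dense(1)(x)

    nodes = find_nodes_by_inputs_and_outputs(inputs, outputs)
    # `nodes` holds the call nodes of the two `Dense` layers, discovered by
    # walking the graph from `outputs` back to `inputs`.
    ```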
    """
    # We walk the graph bottom up, starting from the output nodes, and keep
    # tracing upstream nodes until we find all the input nodes. We don't use a
    # top-down search here since we don't know whether a certain node is in
    # the graph between inputs and outputs, e.g. a functional graph could have
    # multiple outputs, and the user could choose a subset of them to build
    # the model. The bottom-up approach ensures that every node we visit is
    # actually in use. If we reach the top without finding the nodes in
    # `inputs`, that's an error, since the user didn't specify the correct
    # inputs.
    start_keras_tensors = tf.nest.flatten(outputs)
    end_keras_tensors = tf.nest.flatten(inputs)

    for t in start_keras_tensors + end_keras_tensors:
        if not node_module.is_keras_tensor(t):
            raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(t))
    end_ids = {id(kt) for kt in end_keras_tensors}
    # Track all the end tensors we have found so far. If we don't reach all
    # the user-specified Keras inputs by the end of the search, the inputs are
    # disconnected from the outputs and we raise an error.
    end_ids_found = set()

    nodes_to_visit = []
    nodes_in_graph = []
    node_id_visited = set()
    for t in start_keras_tensors:
        nodes_to_visit.append(t.node)

    while nodes_to_visit:
        node = nodes_to_visit.pop(0)
        if id(node) in node_id_visited:
            continue
        node_id_visited.add(id(node))
        nodes_in_graph.append(node)
        # Visit every input KerasTensor that feeds the current node.
        for kt in node.keras_inputs:
            if id(kt) in end_ids:
                # We found one of the model inputs; stop tracing upstream.
                end_ids_found.add(id(kt))
                continue

            inbound_node = kt.node
            # If this is a `tf.keras.Input` node, we have reached the top of
            # the upstream tracing and cannot trace any further. Raise an
            # error, since this input was not found among the user-specified
            # inputs.
            if inbound_node.is_input:
                raise ValueError(
                    "Found an input tensor that is needed to compute the "
                    "outputs but is not included in the provided model "
                    "inputs. Please make sure the tensor {} is included in "
                    "the model inputs when building the functional "
                    "model.".format(kt)
                )
            nodes_to_visit.append(inbound_node)

    # Final check: make sure we have reached all the user-specified inputs.
    if end_ids != end_ids_found:
        unvisited_inputs = [
            kt for kt in end_keras_tensors if id(kt) not in end_ids_found
        ]
        raise ValueError(
            "Found unvisited input tensors that are disconnected from "
            "the outputs: {}".format(unvisited_inputs)
        )
    return nodes_in_graph


def clone_graph_nodes(inputs, outputs):
    """Clone the `Node` between the inputs and output tensors.

    This function is used to create a new functional model from intermediate
    Keras tensors. Cloning the nodes mimics re-executing all the layer
    `__call__` methods that built the original functional graph network. The
    cloned nodes are appended to the layers' inbound node lists.

    Note that a new `tf.keras.Input` is created for any item in `inputs` that
    is not already produced by `tf.keras.Input`.

    Args:
      inputs: A nested structure of keras_tensors.
      outputs: A nested structure of keras_tensors.

    Returns:
      A pair of inputs and outputs, with cloned keras_tensors. They can be used
      to create a new functional model.
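
    Example (a minimal sketch; builds a sub-model from the intermediate
    tensor `x`):

    ```python
    inputs = tf.keras.Input(shape=(10,))
    x = tf.keras.layers.Dense(5)(inputs)
    outputs = tf.keras.layers.Dense(1)(x)

    # A new `tf.keras.Input` is created to stand in for the intermediate
    # tensor `x`, and the `Dense(1)` call node is cloned onto it.
    cloned_inputs, cloned_outputs = clone_graph_nodes(x, outputs)
    sub_model = tf.keras.Model(cloned_inputs, cloned_outputs)
    ```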
    """
    nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs)
    cloned_inputs = []
    cloned_outputs = []
    # We not only need to create copies of Nodes (to mimic the calls), we
    # also need to clone the KerasTensors to avoid overriding the
    # `_keras_history` attached to them. The following dict tracks every
    # KerasTensor we have cloned: the key is the `id()` of the original
    # KerasTensor and the value is the cloned KerasTensor instance.
    kt_id_mapping = {}

    for kt_input in tf.nest.flatten(inputs):
        if kt_input.node.is_input:
            # Leave any existing KerasTensor from `tf.keras.Input` as is.
            cloned_inputs.append(kt_input)
            kt_id_mapping[id(kt_input)] = kt_input
        else:
            # Create a new `tf.keras.Input` for any intermediate KerasTensor.
            cpy = _clone_keras_tensor(kt_input)
            cloned_input = input_layer_module.Input(tensor=cpy)
            cloned_inputs.append(cloned_input)
            kt_id_mapping[id(kt_input)] = cloned_input
    cloned_inputs = tf.nest.pack_sequence_as(inputs, cloned_inputs)

    for kt_output in tf.nest.flatten(outputs):
        cpy = _clone_keras_tensor(kt_output)
        # We reuse the `_keras_history` here, which contains the old
        # information. It is used by the Node constructor to check whether
        # the tensor "is_keras_tensor()". The history will be overridden by
        # the Node constructor for the corresponding layer output anyway.
        cpy._keras_history = kt_output._keras_history
        cloned_outputs.append(cpy)
        kt_id_mapping[id(kt_output)] = cpy
    cloned_outputs = tf.nest.pack_sequence_as(outputs, cloned_outputs)

    for node in nodes_to_clone:
        # Clone any KerasTensors to avoid overriding their `_keras_history`,
        # or reuse an existing clone if the tensor has already been cloned.
        output_copy = clone_keras_tensors(node.output_tensors, kt_id_mapping)
        call_args_copy = clone_keras_tensors(node.call_args, kt_id_mapping)
        call_kwargs_copy = clone_keras_tensors(node.call_kwargs, kt_id_mapping)
        # Create new nodes based on the existing node information. A Node
        # wires itself to its inbound and outbound layers: the Node
        # constructor updates this layer's `self._inbound_nodes`, sets
        # `_keras_history` on the outputs, and adds itself to the
        # `_outbound_nodes` of the layers that produced the inputs to this
        # layer call.
        node_module.Node(
            node.layer,
            call_args=call_args_copy,
            call_kwargs=call_kwargs_copy,
            outputs=output_copy,
        )
    return cloned_inputs, cloned_outputs


def clone_keras_tensors(args, keras_tensor_mapping):
    """Clone the keras tensors from the inputs.

    For any KerasTensor instance in `args`, a new copy is created if it has
    not been cloned yet (as tracked by `keras_tensor_mapping`). Instances of
    any other type are left unchanged. This function is useful when cloning
    Nodes, since a KerasTensor can't be reused across models.

    Args:
      args: A nested structure of objects, which could contain KerasTensor.
      keras_tensor_mapping: A dict mapping the `id()` of an original
        KerasTensor to its cloned instance. The dict is updated with newly
        copied KerasTensor instances within this method.

    Returns:
      The same structure as `args`, with KerasTensors cloned.
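
    Example (a minimal sketch; `mapping` is a fresh dict tracking clones):

    ```python
    mapping = {}
    inputs = tf.keras.Input(shape=(10,))
    x = tf.keras.layers.Dense(5)(inputs)

    (x_copy,) = clone_keras_tensors([x], mapping)
    # Cloning the same tensor again reuses the entry stored in `mapping`.
    (same_copy,) = clone_keras_tensors([x], mapping)
    assert same_copy is x_copy
    ```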
    """
    result = []
    for obj in tf.nest.flatten(args):
        if node_module.is_keras_tensor(obj):
            if id(obj) in keras_tensor_mapping:
                cpy = keras_tensor_mapping[id(obj)]
            else:
                # Create a copy of the KerasTensor if not cloned before.
                cpy = _clone_keras_tensor(obj)
                cpy._keras_history = obj._keras_history
                keras_tensor_mapping[id(obj)] = cpy
            result.append(cpy)
        else:
            result.append(obj)
    return tf.nest.pack_sequence_as(args, result)


def _clone_keras_tensor(kt):
    """Create an identical keras_tensor based on the input.

    We use keras_tensor_to_placeholder and keras_tensor_from_tensor to make sure
    inferred shape are not lost during the copy.

    Args:
      kt: the input KerasTensor.

    Returns:
      An identical copy of the input KerasTensor.
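
    Example (a minimal sketch):

    ```python
    kt = tf.keras.Input(shape=(4,))
    cpy = _clone_keras_tensor(kt)
    # `cpy` has the same shape and dtype as `kt`, but no `_keras_history`;
    # callers attach the history themselves when needed.
    ```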
    """
    # Create a scratch graph since we don't intend to use the placeholders.
    with backend._scratch_graph() as scratch_graph:
        with scratch_graph.as_default():
            placeholder = keras_tensor.keras_tensor_to_placeholder(kt)
            return keras_tensor.keras_tensor_from_tensor(placeholder)
