# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Code for model cloning, plus model-related API entries."""

import tensorflow.compat.v2 as tf

from keras import backend
from keras import metrics as metrics_module
from keras.engine import functional
from keras.engine import sequential
from keras.engine import training
from keras.engine import training_v1
from keras.engine.base_layer import AddMetric
from keras.engine.base_layer import Layer
from keras.engine.input_layer import Input
from keras.engine.input_layer import InputLayer
from keras.optimizers import optimizer_v1
from keras.utils import generic_utils
from keras.utils import version_utils
from keras.utils.generic_utils import CustomObjectScope

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

# API entries importable from `keras.models`. These aliases are also the
# classes the cloning helpers below use for their `isinstance` checks.
Model = training.Model
Sequential = sequential.Sequential


def share_weights(layer):
    """Layer-cloning callable that keeps weights by reusing the layer itself.

    Args:
      layer: The layer to "clone".

    Returns:
      The very same `layer` object that was passed in.
    """
    return layer


def _clone_layer(layer):
    """Builds a new object of `layer`'s class from `layer`'s config.

    Args:
      layer: Object exposing `get_config()` and whose class exposes
        `from_config()`.

    Returns:
      Whatever `from_config` returns for the config of `layer`.
    """
    rebuild_from_config = layer.__class__.from_config
    return rebuild_from_config(layer.get_config())


def _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes):
    """Adds `ancillary_layers` to `model`, placing `AddMetric` layers last.

    Args:
      model: Model receiving the layers through `_insert_layers`.
      ancillary_layers: Layers to insert.
      metrics_names: Sequence of metric names; `AddMetric` layers are ordered
        by the position of their `metric_name` in this sequence.
      new_nodes: Iterable of nodes, forwarded as `relevant_nodes`.
    """

    def _metric_rank(metric_layer):
        return metrics_names.index(metric_layer.metric_name)

    # `AddMetric` layers are sorted so they agree with `metrics_names`.
    ranked_metric_layers = sorted(
        (item for item in ancillary_layers if isinstance(item, AddMetric)),
        key=_metric_rank,
    )
    remaining_layers = [
        item for item in ancillary_layers if not isinstance(item, AddMetric)
    ]
    model._insert_layers(
        remaining_layers + ranked_metric_layers,
        relevant_nodes=list(new_nodes),
    )


def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
    """Replays the nodes of `nodes_by_depth` on the layers in `layer_map`.

    Nodes are visited from the largest depth key to the smallest. A node is
    only replayed when every one of its input tensors already has an entry in
    `tensor_map`; otherwise it is skipped.

    Args:
      nodes_by_depth: Mapping from depth to the nodes at that depth; provides
        the structure information used to create new nodes.
      layer_fn: Function used to clone a layer missing from `layer_map`.
      layer_map: Map from layers in the reference model to new layers.
      tensor_map: Map from tensors in the reference model to newly computed
        tensors.

    Returns:
      A set of new nodes. `layer_map` and `tensor_map` are updated in place.
    """

    def _to_new_tensor(value):
        # Anything without a `tensor_map` entry is passed through unchanged.
        return tensor_map.get(value, value)

    created_nodes = set()
    for depth in sorted(nodes_by_depth, reverse=True):
        for node in nodes_by_depth[depth]:
            original_layer = node.outbound_layer

            if original_layer in layer_map:
                # Reuse the previously cloned layer.
                target_layer = layer_map[original_layer]
                # Don't call InputLayer multiple times.
                if isinstance(target_layer, InputLayer):
                    continue
            else:
                target_layer = layer_fn(original_layer)
                layer_map[original_layer] = target_layer

            # Only replay the node once all of its inputs have been mapped.
            inputs_ready = all(
                tensor in tensor_map
                for tensor in tf.nest.flatten(node.input_tensors)
            )
            if not inputs_ready:
                continue

            new_outputs = target_layer(
                *tf.nest.map_structure(_to_new_tensor, node.call_args),
                **tf.nest.map_structure(_to_new_tensor, node.call_kwargs),
            )

            # Thread-safe way to keep track of what node was created.
            flat_new_outputs = tf.nest.flatten(new_outputs)
            created_index = flat_new_outputs[0]._keras_history.node_index
            created_nodes.add(target_layer._inbound_nodes[created_index])

            for old_tensor, new_tensor in zip(
                tf.nest.flatten(node.output_tensors), flat_new_outputs
            ):
                tensor_map[old_tensor] = new_tensor
    return created_nodes


def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
    """Clone a functional `Model` instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Input layers are always cloned.

    Args:
        model: Instance of `Model`.
        input_tensors: optional list of input tensors
            to build the model upon. If not provided,
            placeholders will be created. Tensors are matched to the
            model's input layers by position; input layers without a
            matching tensor are re-created from their own config.
        layer_fn: callable to be applied on non-input layers in the model. By
            default it clones the layer. Another example is to preserve the
            layer to share the weights. This is required when we create a
            per-replica copy of the model with distribution strategy; we want
            the weights to be shared but still feed inputs separately so we
            create new input layers.

    Returns:
        An instance of `Model` reproducing the behavior
        of the original model, on top of new input tensors,
        using newly instantiated weights.

    Raises:
        ValueError: in case of invalid `model` argument value, invalid
        `layer_fn` argument value, or more `input_tensors` than the model
        has input layers.
    """
    if not isinstance(model, Model):
        raise ValueError(
            "Expected `model` argument "
            f"to be a `Model` instance. Received: model={model}"
        )
    if isinstance(model, Sequential):
        raise ValueError(
            "Expected `model` argument "
            "to be a functional `Model` instance, "
            f"got a `Sequential` instance instead: {model}"
        )
    if not model._is_graph_network:
        raise ValueError(
            "Expected `model` argument "
            "to be a functional `Model` instance, "
            f"but got a subclassed model instead: {model}"
        )

    new_input_layers = {}  # Cache for created layers.
    if input_tensors is not None:
        # Make sure that all input tensors come from a Keras layer.
        input_tensors = tf.nest.flatten(input_tensors)
        original_input_layers = model._input_layers
        # Surplus tensors cannot be matched to an input layer; report that
        # explicitly instead of failing with a bare IndexError.
        if len(input_tensors) > len(original_input_layers):
            raise ValueError(
                "Expected `input_tensors` to contain at most "
                f"{len(original_input_layers)} tensor(s), one per model "
                f"input. Received {len(input_tensors)} tensors: "
                f"input_tensors={input_tensors}"
            )
        for original_input_layer, input_tensor in zip(
            original_input_layers, input_tensors
        ):
            # Create a new input layer if the tensor is originally not from
            # a Keras layer.
            if not backend.is_keras_tensor(input_tensor):
                name = original_input_layer.name
                input_tensor = Input(
                    tensor=input_tensor, name="input_wrapper_for_" + name
                )
            # Cache the input layer that replaces the original one.
            new_input_layers[
                original_input_layer
            ] = input_tensor._keras_history.layer

    if not callable(layer_fn):
        raise ValueError(
            "Expected `layer_fn` argument to be a callable. "
            f"Received: layer_fn={layer_fn}"
        )

    model_configs, created_layers = _clone_layers_and_model_config(
        model, new_input_layers, layer_fn
    )
    # Reconstruct model from the config, using the cloned layers.
    (
        input_tensors,
        output_tensors,
        created_layers,
    ) = functional.reconstruct_from_config(
        model_configs, created_layers=created_layers
    )
    metrics_names = model.metrics_names
    cloned_model = Model(input_tensors, output_tensors, name=model.name)
    # Layers not directly tied to outputs of the Model, such as loss layers
    # created in `add_loss` and `add_metric`. `cloned_model.layers` is read
    # once up front rather than once per created layer.
    layers_in_clone = set(cloned_model.layers)
    ancillary_layers = [
        layer
        for layer in created_layers.values()
        if layer not in layers_in_clone
    ]
    # TODO(b/162887610): This may need to adjust the inbound node index if the
    # created layers had already been used to define other models.
    if ancillary_layers:
        new_nodes = tf.nest.flatten(
            [
                layer.inbound_nodes[1:]
                if functional._should_skip_first_node(layer)
                else layer.inbound_nodes
                for layer in created_layers.values()
            ]
        )
        _insert_ancillary_layers(
            cloned_model, ancillary_layers, metrics_names, new_nodes
        )
    return cloned_model


def _clone_layers_and_model_config(model, input_layers, layer_fn):
    """Clones every layer and returns the network config without layer configs.

    Only the node graph is retrieved from the model config: the serialization
    hook handed to `functional.get_network_config` records a replacement for
    each layer and returns an empty dictionary instead of serializing it.
    This matters because `layer_fn` might not rely on `layer.get_config()`,
    so some custom layers do not define `get_config` and retrieving their
    config would result in errors.

    Args:
      model: A Functional model.
      input_layers: Dictionary mapping input layers in `model` to new input
        layers.
      layer_fn: Function used to clone all non-input layers.

    Returns:
      Model config object, and a dictionary mapping layer names to the newly
      created layers.
    """
    layers_by_name = {}

    def _register_replacement(layer):
        # Pick the replacement: a caller-supplied input layer, a re-created
        # `InputLayer`, or the result of `layer_fn`.
        if layer in input_layers:
            replacement = input_layers[layer]
        elif layer in model._input_layers:
            replacement = InputLayer(**layer.get_config())
        else:
            replacement = layer_fn(layer)
        layers_by_name[layer.name] = replacement
        # Whenever the network config asks for the layer serialization,
        # hand back a dummy dictionary.
        return {}

    network_config = functional.get_network_config(
        model, serialize_layer_fn=_register_replacement
    )
    return network_config, layers_by_name


def _remove_ancillary_layers(model, layer_map, layers):
    """Splits `layers` into regular layers and ancillary layers of `model`.

    Ancillary layers are part of the model topology but not used to compute the
    model outputs, e.g., layers from `add_loss` and `add_metric`. They are
    identified as the outbound layers of nodes stored at negative depth.

    Args:
      model: A Keras Model.
      layer_map: A map from layers in the `model` to those in `layers`.
      layers: A list of all layers.

    Returns:
      Two lists of layers: (1) `layers` with the ancillary layers removed, and
      (2) the ancillary layers. When `model` is not a graph network, `layers`
      itself is returned together with an empty list.
    """
    if not model._is_graph_network:
        return layers, []

    nodes_by_depth = model._nodes_by_depth
    # Negative depths, walked from -1 downwards (inputs towards outputs).
    negative_depths = sorted(
        (depth for depth in nodes_by_depth if depth < 0), reverse=True
    )
    extras = []  # Additional layers for computing losses and metrics.
    for depth in negative_depths:
        extras.extend(
            layer_map[node.outbound_layer] for node in nodes_by_depth[depth]
        )

    kept = [candidate for candidate in layers if candidate not in extras]
    return kept, extras


def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
    """Clone a `Sequential` model instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Args:
        model: Instance of `Sequential`.
        input_tensors: optional input tensor (or a list/tuple holding exactly
            one tensor) to build the model upon. If not provided,
            placeholders will be created.
        layer_fn: callable to be applied on non-input layers in the model. By
            default it clones the layer. Another example is to preserve the
            layer to share the weights. This is required when we create a
            per-replica copy of the model with distribution strategy; we want
            the weights to be shared but still feed inputs separately so we
            create new input layers.

    Returns:
        An instance of `Sequential` reproducing the behavior
        of the original model, on top of new input tensors,
        using newly instantiated weights.

    Raises:
        ValueError: in case of invalid `model` argument value or `layer_fn`
        argument value, when `input_tensors` does not hold exactly one
        tensor, or when that tensor comes from a Keras layer other than an
        `InputLayer`.
    """
    if not isinstance(model, Sequential):
        raise ValueError(
            "Expected `model` argument "
            "to be a `Sequential` model instance. "
            f"Received: model={model}"
        )

    if not callable(layer_fn):
        raise ValueError(
            "Expected `layer_fn` argument to be a callable. "
            f"Received: layer_fn={layer_fn}"
        )

    layers = []  # Layers needed to compute the model's outputs.
    layer_map = {}
    # Ensure that all layers are cloned. The model's layers
    # property will exclude the initial InputLayer (if it exists) in the model,
    # resulting in a different Sequential model structure.
    for layer in model._flatten_layers(include_self=False, recursive=False):
        if isinstance(layer, InputLayer) and input_tensors is not None:
            # If input tensors are provided, the original model's InputLayer is
            # overwritten with a different InputLayer.
            continue
        cloned_layer = (
            _clone_layer(layer)
            if isinstance(layer, InputLayer)
            else layer_fn(layer)
        )
        layers.append(cloned_layer)
        layer_map[layer] = cloned_layer
    layers, ancillary_layers = _remove_ancillary_layers(
        model, layer_map, layers
    )

    if input_tensors is None:
        cloned_model = Sequential(layers=layers, name=model.name)
    else:
        # Normalize tuples *before* counting. `generic_utils.to_list` is
        # expected to wrap a tuple as a single element (hence this explicit
        # conversion), so counting first would let a multi-tensor tuple
        # through and silently drop every tensor but the first.
        if isinstance(input_tensors, tuple):
            input_tensors = list(input_tensors)
        if len(generic_utils.to_list(input_tensors)) != 1:
            raise ValueError(
                "To clone a `Sequential` model, we expect at most one tensor "
                "as part of `input_tensors`. "
                f"Received: input_tensors={input_tensors}"
            )
        # Overwrite the original model's input layer.
        x = generic_utils.to_list(input_tensors)[0]
        if backend.is_keras_tensor(x):
            origin_layer = x._keras_history.layer
            if isinstance(origin_layer, InputLayer):
                cloned_model = Sequential(
                    layers=[origin_layer] + layers, name=model.name
                )
            else:
                raise ValueError(
                    "Cannot clone a `Sequential` model on top "
                    "of a tensor that comes from a Keras layer "
                    "other than an `InputLayer`. "
                    "Use the Functional API instead. "
                    f"Received: input_tensors={input_tensors}"
                )
        else:
            # Wrap a non-Keras tensor in a fresh input layer.
            input_tensor = Input(
                tensor=x, name="input_wrapper_for_" + str(x.name)
            )
            input_layer = input_tensor._keras_history.layer
            cloned_model = Sequential(
                layers=[input_layer] + layers, name=model.name
            )

    if not ancillary_layers:
        return cloned_model

    tensor_map = {}  # Maps tensors from `model` to those in `cloned_model`.
    for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
        nodes = model._nodes_by_depth[depth]
        # This should be safe in a Sequential model. In an arbitrary network,
        # you need to sort using the outbound layer of the node as a key.
        for cloned_node, node in zip(cloned_nodes, nodes):
            if isinstance(cloned_node.output_tensors, list):
                for j, output_tensor in enumerate(cloned_node.output_tensors):
                    tensor_map[node.output_tensors[j]] = output_tensor
            else:
                tensor_map[node.output_tensors] = cloned_node.output_tensors
    # Ancillary nodes have negative depth.
    new_nodes = _make_new_nodes(
        {
            depth: nodes
            for depth, nodes in model._nodes_by_depth.items()
            if depth < 0
        },
        layer_fn,
        layer_map,
        tensor_map,
    )
    _insert_ancillary_layers(
        cloned_model, ancillary_layers, model.metrics_names, new_nodes
    )
    return cloned_model


@keras_export("keras.models.clone_model")
def clone_model(model, input_tensors=None, clone_function=None):
    """Clone a Functional or Sequential `Model` instance.

    Cloning a model resembles calling it on new inputs, with one difference:
    new layers (and therefore new weights) are created rather than the
    existing layers' weights being shared.

    Note that `clone_model` will not preserve the uniqueness of shared objects
    within the model (e.g. a single variable attached to two distinct layers
    will be restored as two separate variables).

    Args:
        model: Instance of `Model`
            (could be a Functional model or a Sequential model).
        input_tensors: optional list of input tensors or InputLayer objects
            to build the model upon. If not provided,
            new `Input` objects will be created.
        clone_function: Callable used to clone each layer of the target model
            (except `InputLayer` instances). It receives the layer instance
            to be cloned and returns the layer instance to use in the model
            copy. When left unspecified, the following
            serialization/deserialization function is used:
            `lambda layer: layer.__class__.from_config(layer.get_config())`.
            A custom callable lets you customize the copy, e.g. by wrapping
            certain layers of interest (for instance replacing all `LSTM`
            instances with equivalent `Bidirectional(LSTM(...))` instances).

    Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new input tensors,
      using newly instantiated weights. The cloned model may behave
      differently from the original model if a custom `clone_function`
      modifies the layer.

    Example:

    ```python
    # Create a test Sequential model.
    model = keras.Sequential([
        keras.Input(shape=(728,)),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid'),
    ])
    # Create a copy of the test model (with freshly initialized weights).
    new_model = clone_model(model)
    ```

    Note that subclassed models cannot be cloned, since their internal
    layer structure is not known. To achieve equivalent functionality
    as `clone_model` in the case of a subclassed model, simply make sure
    that the model class implements `get_config()`
    (and optionally `from_config()`), and call:

    ```python
    new_model = model.__class__.from_config(model.get_config())
    ```
    """
    with generic_utils.DisableSharedObjectScope():
        layer_fn = _clone_layer if clone_function is None else clone_function
        # Sequential models have a dedicated cloning routine; everything
        # else goes through the functional one.
        clone_impl = (
            _clone_sequential_model
            if isinstance(model, Sequential)
            else _clone_functional_model
        )
        return clone_impl(
            model, input_tensors=input_tensors, layer_fn=layer_fn
        )


# "Clone" a subclassed model by resetting all of the attributes.
def _in_place_subclassed_model_reset(model):
    """Substitute for model cloning that works for subclassed models.

    Subclassed models cannot be cloned because their topology is not
    serializable. To "instantiate" an identical model in a new TF graph, we
    reuse the original model object, but we clear its state.

    After calling this function on a model instance, you can use the model
    instance as if it were a model clone (in particular you can use it in a new
    graph).

    This method clears the state of the input model. It is thus destructive.
    However the original state can be restored fully by calling
    `in_place_subclassed_model_state_restoration`. The state captured for
    that restoration is the one seen by the first reset; resetting again
    before restoring keeps that first snapshot instead of replacing it.

    Args:
      model: Instance of a Keras model created via subclassing.

    Raises:
      ValueError: If a layer attribute of the model itself has a non-empty
        `layers` attribute (nested layers), if the model has a list/tuple
        attribute made only of layers, or if the model uses a subclassed
        model as inner layer.
    """
    assert (
        not model._is_graph_network
    )  # Only makes sense for subclassed networks
    # Select correct base class for new Model.
    version_utils.swap_class(
        model.__class__,
        training.Model,
        training_v1.Model,
        tf.compat.v1.executing_eagerly_outside_functions(),
    )
    # Retrieve all layers tracked by the model as well as their attribute names
    attributes_cache = {}
    for name in dir(model):
        # Skip attrs that track other trackables.
        if name == "submodules" or name == "_self_tracked_trackables":
            continue

        try:
            value = getattr(model, name)
        except (AttributeError, ValueError, TypeError):
            # Attributes that cannot be read are simply not cached.
            continue
        if isinstance(value, Layer):
            attributes_cache[name] = value
            assert value in model.layers
            if hasattr(value, "layers") and value.layers:
                raise ValueError(
                    "We do not support the use of nested layers "
                    "in `model_to_estimator` at this time. Found nested "
                    f"layer: {value}"
                )
        elif isinstance(value, (list, tuple)) and name not in (
            "layers",
            "_layers",
            "metrics",
            "_compile_metric_functions",
            "_output_loss_metrics",
        ):
            # Handle case: list/tuple of layers (also tracked by the Network
            # API).
            if value and all(isinstance(val, Layer) for val in value):
                raise ValueError(
                    "We do not support the use of list-of-layers "
                    "attributes in subclassed models used with "
                    "`model_to_estimator` at this time. Found list "
                    f"model: {name}"
                )

    # Replace layers on the model with fresh layers
    layers_to_names = {value: key for key, value in attributes_cache.items()}
    original_layers = list(
        model._flatten_layers(include_self=False, recursive=False)
    )
    # Attribute tracking is switched off while layers are swapped and put
    # back to its previous value at the end of the function.
    setattr_tracking = model._setattr_tracking
    model._setattr_tracking = False
    model._self_tracked_trackables = []
    for layer in original_layers:  # We preserve layer order.
        config = layer.get_config()
        # This will not work for nested subclassed models used as layers.
        # This would be theoretically possible to support, but would add
        # complexity. Only do it if users complain.
        if isinstance(layer, training.Model) and not layer._is_graph_network:
            raise ValueError(
                "We do not support the use of nested subclassed models "
                "in `model_to_estimator` at this time. Found nested "
                f"model: {layer}"
            )
        fresh_layer = layer.__class__.from_config(config)
        name = layers_to_names[layer]
        setattr(model, name, fresh_layer)
        model._self_tracked_trackables.append(fresh_layer)

    # Cache original model build attributes (in addition to layers). The cache
    # is only written when no earlier snapshot is pending restoration:
    # overwriting it on a repeated reset would replace the original layers
    # with the already-reset ones and drop the build attributes, making the
    # original state unrecoverable.
    if (
        not hasattr(model, "_original_attributes_cache")
        or model._original_attributes_cache is None
    ):
        if model.built:
            attributes_to_cache = [
                "inputs",
                "outputs",
                "total_loss",
                "optimizer",
                "train_function",
                "test_function",
                "predict_function",
                "_training_endpoints",
                "_collected_trainable_weights",
                "_feed_inputs",
                "_feed_input_names",
                "_feed_input_shapes",
            ]
            for name in attributes_to_cache:
                attributes_cache[name] = getattr(model, name)
        model._original_attributes_cache = attributes_cache
    _reset_build_compile_trackers(model)
    model._setattr_tracking = setattr_tracking


def _reset_build_compile_trackers(model):
    """Marks `model` as neither built nor compiled.

    Only a handful of key attributes are reset here (`built`, `inputs`,
    `outputs`, `_is_compiled`, `optimizer`, and `_v1_compile_was_called` when
    not executing eagerly outside functions). Everything else is expected to
    be over-written on calling build/compile/etc. This is somewhat fragile,
    insofar as other code checks for the presence of such attributes as
    evidence of having been built/compiled/etc. Pending a better way to do
    this, resetting these attributes is what allows building and compiling
    again.

    Args:
      model: the model that is being reset
    """
    # Build state.
    for attr_name, cleared_value in (
        ("built", False),
        ("inputs", None),
        ("outputs", None),
    ):
        setattr(model, attr_name, cleared_value)
    # Compile state.
    model._is_compiled = False
    eager_outside_functions = (
        tf.compat.v1.executing_eagerly_outside_functions()
    )
    if not eager_outside_functions:
        model._v1_compile_was_called = False
    model.optimizer = None


@keras_export(
    "keras.__internal__.models.in_place_subclassed_model_state_restoration",
    v1=[],
)
def in_place_subclassed_model_state_restoration(model):
    """Puts back the state a model had before it was "reset".

    This reverses `_in_place_subclassed_model_reset`, which
    `clone_and_build_model` calls when `in_place_reset` is set to True. If the
    model carries no cached attributes, its build/compile trackers are reset
    instead, leaving it in the state of a never-called model.

    Args:
      model: Instance of a Keras model created via subclassing, on which
        `_in_place_subclassed_model_reset` was previously called.
    """
    assert not model._is_graph_network

    cached_attributes = getattr(model, "_original_attributes_cache", None)
    if cached_attributes is None:
        # Restore to the state of a never-called model.
        _reset_build_compile_trackers(model)
        return

    # Models have sticky attribute assignment, so we want to be careful to
    # add back the previous attributes and track Layers by their original
    # names without adding dependencies on "utility" attributes which Models
    # exempt when they're constructed.
    previous_tracking = model._setattr_tracking
    model._setattr_tracking = False
    model._self_tracked_trackables = []
    for attr_name, attr_value in cached_attributes.items():
        setattr(model, attr_name, attr_value)
        if isinstance(attr_value, Layer):
            model._self_tracked_trackables.append(attr_value)
    model._original_attributes_cache = None
    model._setattr_tracking = previous_tracking


@keras_export("keras.__internal__.models.clone_and_build_model", v1=[])
def clone_and_build_model(
    model,
    input_tensors=None,
    target_tensors=None,
    custom_objects=None,
    compile_clone=True,
    in_place_reset=False,
    optimizer_iterations=None,
    optimizer_config=None,
):
    """Clone a `Model` and build/compile it with the same settings used before.

    This function can be run in the same graph or in a separate graph from the
    model. When using a separate graph, `in_place_reset` must be `False`.

    Note that, currently, the clone produced from this function may not work
    with TPU DistributionStrategy. Try at your own risk.

    When `compile_clone` is True, the compile arguments of `model` are read
    once and `model._get_compile_args` is replaced by a function returning
    them. The clone is compiled from a copy of those arguments, so the values
    cached on `model` are not altered by the clone's optimizer, metrics or
    target tensors.

    Args:
      model: `tf.keras.Model` object. Can be Functional, Sequential, or
        sub-classed.
      input_tensors: Optional list or dictionary of input tensors to build the
        model upon. If not provided, placeholders will be created.
      target_tensors: Optional list of target tensors for compiling the model.
        If not provided, placeholders will be created.
      custom_objects: Optional dictionary mapping string names to custom classes
        or functions.
      compile_clone: Boolean, whether to compile model clone (default `True`).
      in_place_reset: Boolean, whether to reset the model in place. Only used if
        the model is a subclassed model whose `get_config`/`from_config`
        round trip raises `NotImplementedError`. In that case this argument
        must be set to `True` (default `False`). To restore the
        original model, use the function
        `in_place_subclassed_model_state_restoration(model)`.
      optimizer_iterations: An iterations variable that will be incremented by
        the optimizer if the clone is compiled. This argument is used when a
        Keras model is cloned into an Estimator model function, because
        Estimators create their own global step variable.
      optimizer_config: Optimizer config dictionary or list of dictionary
        returned from `get_config()`. This argument should be defined if
        `clone_and_build_model` is called in a different graph or session from
        the original model, and the optimizer is an instance of `OptimizerV2`.

    Returns:
      Clone of the model.

    Raises:
      ValueError: Cloning fails in the following cases
        - cloning a subclassed model with `in_place_reset` set to False.
        - compiling the clone when the original model has not been compiled.
    """
    # Grab optimizer now, as we reset-in-place for subclassed models, but
    # want to maintain access to the original optimizer.
    orig_optimizer = model.optimizer
    if compile_clone and not orig_optimizer:
        raise ValueError(
            "Error when cloning model: `compile_clone` was set to True, but "
            f"the original model has not been compiled. Received: model={model}"
        )

    if compile_clone:
        original_compile_args = model._get_compile_args()
        # Allows this method to be robust to switching graph and eager classes.
        model._get_compile_args = lambda: original_compile_args
        # Work on a shallow copy: the optimizer, target tensors and cloned
        # metrics written below belong to the clone and must not leak into
        # the dict that `model._get_compile_args` keeps returning.
        compile_args = dict(original_compile_args)

    with CustomObjectScope(custom_objects or {}):
        if model._is_graph_network:
            clone = clone_model(model, input_tensors=input_tensors)
        elif isinstance(model, Sequential):
            clone = clone_model(model, input_tensors=input_tensors)
            if (
                not clone._is_graph_network
                and model._build_input_shape is not None
            ):
                if tf.compat.v1.executing_eagerly_outside_functions():
                    clone.build(model._build_input_shape)
                else:
                    clone._set_inputs(
                        backend.placeholder(
                            model._build_input_shape,
                            dtype=model.inputs[0].dtype,
                        )
                    )
        else:
            try:
                # Prefer cloning the model if serial/deserial logic is
                # implemented for subclassed model.
                clone = model.__class__.from_config(model.get_config())
            except NotImplementedError:
                logging.warning(
                    "This model is a subclassed model. Please implement "
                    "`get_config` and `from_config` to better support "
                    "cloning the model."
                )
                if not in_place_reset:
                    raise ValueError(
                        f"This model ({model}) is a subclassed model. "
                        "Such a model cannot be cloned, but there is a "
                        "workaround where the model is reset in-place. "
                        "To use this, please set the "
                        "argument `in_place_reset` to `True`. This will reset "
                        "the attributes in the original model. "
                        "To restore the attributes, call "
                        "`in_place_subclassed_model_state_restoration(model)`."
                    )
                # The "clone" is the original object with its state cleared.
                clone = model
                _in_place_subclassed_model_reset(clone)
            if input_tensors is not None:
                # A single tensor wrapped in a list/tuple is unwrapped first.
                if (
                    isinstance(input_tensors, (list, tuple))
                    and len(input_tensors) == 1
                ):
                    input_tensors = input_tensors[0]
                clone._set_inputs(input_tensors)

    if compile_clone:
        if isinstance(orig_optimizer, optimizer_v1.TFOptimizer):
            optimizer = optimizer_v1.TFOptimizer(
                orig_optimizer.optimizer, optimizer_iterations
            )
            backend.track_tf_optimizer(optimizer)
        else:
            if not isinstance(orig_optimizer, (tuple, list)):
                orig_optimizer = [orig_optimizer]
            if optimizer_config is None:
                # Rebuild each optimizer from its own config.
                optimizer = [
                    opt.__class__.from_config(opt.get_config())
                    for opt in orig_optimizer
                ]
            elif isinstance(optimizer_config, dict):
                # A single config applies to the first optimizer only.
                optimizer = [
                    orig_optimizer[0].__class__.from_config(optimizer_config)
                ]
            else:
                # optimizer config is list of dict, same order as
                # orig_optimizer.
                optimizer = [
                    opt.__class__.from_config(opt_config)
                    for (opt, opt_config) in zip(
                        orig_optimizer, optimizer_config
                    )
                ]
            if optimizer_iterations is not None:
                for opt in optimizer:
                    opt.iterations = optimizer_iterations

            if len(optimizer) == 1:
                optimizer = optimizer[0]

        compile_args["optimizer"] = optimizer
        if target_tensors is not None:
            compile_args["target_tensors"] = target_tensors
        # Ensure Metric objects in new model are separate from existing model.
        compile_args["metrics"] = metrics_module.clone_metrics(
            compile_args["metrics"]
        )
        compile_args["weighted_metrics"] = metrics_module.clone_metrics(
            compile_args["weighted_metrics"]
        )
        clone.compile(**compile_args)

    return clone
