# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

import copy
import os

import tensorflow.compat.v2 as tf

import keras
from keras import backend
from keras import losses
from keras import optimizers
from keras.engine import base_layer_utils
from keras.optimizers import optimizer_v1
from keras.utils import generic_utils
from keras.utils import version_utils
from keras.utils.io_utils import ask_to_proceed_with_overwrite

# isort: off
from tensorflow.python.platform import tf_logging as logging


def extract_model_metrics(model):
    """Collects the metrics passed to `compile` into a name-keyed dictionary.

    Used when converting Keras models to Estimators and SavedModels.

    Args:
      model: A `tf.keras.Model` object.

    Returns:
      A dict mapping each compiled metric's `name` to the metric object, or
      `None` when the model has no (or empty) `_compile_metrics`.
    """
    if not getattr(model, "_compile_metrics", None):
        return None

    # TODO(psv/kathywu): use this implementation in model to estimator flow.
    # `model.metrics` is deliberately avoided: it would also include metrics
    # registered through the `add_metric` API.
    metrics_by_name = {}
    for metric in model._compile_metric_functions:
        metrics_by_name[metric.name] = metric
    return metrics_by_name


def model_call_inputs(model, keep_original_batch_size=False):
    """Returns the input signature the model's call function was saved with.

    Because Keras requires tensor inputs to be passed as the first argument,
    the positional part of the signature is a list holding one (possibly
    nested) structure. For instance, a model called with
    {'feature1': <Tensor>, 'feature2': <Tensor>} yields:
    [{'feature1': TensorSpec, 'feature2': TensorSpec}]

    Args:
      model: Keras Model object.
      keep_original_batch_size: If `True`, keep the batch dimension that was
        recorded in the save spec. Defaults to `False`, in which case the
        batch dimension of the returned specs is `None`.

    Returns:
      An `(args, kwargs)` tuple of TensorSpecs for the model call function
      (`kwargs` never contains `training`), or `(None, None)` when the model
      has no save spec.
    """
    saved_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
    if saved_specs is not None:
        return _enforce_names_consistency(saved_specs)
    return None, None


def raise_model_input_error(model):
    """Raises a `ValueError` explaining why `model` cannot be traced for saving.

    `Sequential` models get a message about the missing input shape only. Any
    other model is assumed to be subclassed, so the message also covers a
    missing `call()` implementation.

    Args:
      model: The Keras model that could not be saved.

    Raises:
      ValueError: Always.
    """
    if isinstance(model, keras.models.Sequential):
        raise ValueError(
            f"Model {model} cannot be saved because the input shape is not "
            "available. Please specify an input shape either by calling "
            "`build(input_shape)` directly, or by calling the model on actual "
            "data using `Model()`, `Model.fit()`, or `Model.predict()`."
        )

    # If the model is not a `Sequential`, it is intended to be a subclassed
    # model.
    raise ValueError(
        f"Model {model} cannot be saved either because the input shape is not "
        "available or because the forward pass of the model is not defined. "
        "To define a forward pass, please override `Model.call()`. To specify "
        "an input shape, either call `build(input_shape)` directly, or call "
        "the model on actual data using `Model()`, `Model.fit()`, or "
        "`Model.predict()`. If you have a custom training step, please make "
        "sure to invoke the forward pass in train step through "
        "`Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`."
    )


def trace_model_call(model, input_signature=None):
    """Builds a concrete tf.function around the model call for export.

    Args:
      model: A Keras model.
      input_signature: optional, a list of tf.TensorSpec objects describing
        the model inputs. When omitted, the signature of `model.call` (if it
        is a tf.function) or the model's save spec is used instead.

    Returns:
      A concrete function wrapping the model's call function, traced with
      the resolved input signature. It always returns a flat dict of
      outputs.

    Raises:
      ValueError: if input signature cannot be inferred from the model.
    """
    if input_signature is None and isinstance(
        model.call, tf.__internal__.function.Function
    ):
        input_signature = model.call.input_signature

    if not input_signature:
        model_args, model_kwargs = model_call_inputs(model)
        if model_args is None:
            raise_model_input_error(model)
    else:
        model_args, model_kwargs = input_signature, {}

    @tf.function
    def _wrapped_model(*args, **kwargs):
        """Runs the model in inference mode and flattens its outputs."""
        args, kwargs = model._call_spec.set_arg_value(
            "training", False, args, kwargs, inputs_in_args=True
        )

        with base_layer_utils.call_context().enter(
            model, inputs=None, build_graph=False, training=False, saving=True
        ):
            outputs = model(*args, **kwargs)

        # The exported function must return a flat dict. Functional models
        # provide `output_names`; subclassed models get generated names.
        names = model.output_names
        if names is None:
            from keras.engine import compile_utils

            names = compile_utils.create_pseudo_output_names(outputs)
        return dict(zip(names, tf.nest.flatten(outputs)))

    return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)


def model_metadata(model, include_optimizer=True, require_config=True):
    """Builds the metadata dictionary saved alongside a Keras model.

    Args:
      model: The Keras model being saved.
      include_optimizer: Whether to record the training configuration
        (compile arguments plus the optimizer config) in the metadata.
      require_config: Whether a `NotImplementedError` raised by
        `model.get_config()` should propagate. When `False`, the `config`
        entry is simply left out of `model_config`.

    Returns:
      A dict with `keras_version`, `backend` and `model_config` entries. It
      also has a `training_config` entry when the model has an optimizer,
      `include_optimizer` is `True`, the optimizer is not a
      `optimizer_v1.TFOptimizer`, and `compile` was called.

    Raises:
      NotImplementedError: if `model.get_config()` is not implemented and
        `require_config` is `True`, or if the model's optimizer is a
        `RestoredOptimizer` (loaded from a SavedModel).
    """
    from keras import __version__ as keras_version
    from keras.optimizers.optimizer_v2 import optimizer_v2

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError:
        if require_config:
            raise

    metadata = {
        "keras_version": str(keras_version),
        "backend": backend.backend(),
        "model_config": model_config,
    }

    if not (model.optimizer and include_optimizer):
        return metadata

    if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
        # Raw TF optimizers expose no retrievable state, so only warn.
        logging.warning(
            "TensorFlow optimizers do not "
            "make it possible to access "
            "optimizer attributes or optimizer state "
            "after instantiation. "
            "As a result, we cannot save the optimizer "
            "as part of the model save file. "
            "You will have to compile your model again after loading it. "
            "Prefer using a Keras optimizer instead "
            "(see keras.io/optimizers)."
        )
        return metadata

    if not model._compile_was_called:
        return metadata

    compile_args = model._get_compile_args(user_metrics=False)
    compile_args.pop("optimizer", None)  # Serialized separately below.
    metadata["training_config"] = _serialize_nested_config(compile_args)

    if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
        raise NotImplementedError(
            "Optimizers loaded from a SavedModel cannot be saved. "
            "If you are calling `model.save` or "
            "`tf.keras.models.save_model`, "
            "please set the `include_optimizer` option to `False`. For "
            "`tf.saved_model.save`, "
            "delete the optimizer from the model."
        )

    optimizer_class_name = generic_utils.get_registered_name(
        model.optimizer.__class__
    )
    metadata["training_config"]["optimizer_config"] = {
        "class_name": optimizer_class_name,
        "config": model.optimizer.get_config(),
    }
    return metadata


def should_overwrite(filepath, overwrite):
    """Decides whether writing to `filepath` may proceed.

    Args:
      filepath: Destination path.
      overwrite: If truthy, writing is always allowed.

    Returns:
      `True` when `overwrite` is set or no regular file exists at `filepath`;
      otherwise whatever `ask_to_proceed_with_overwrite(filepath)` returns.
    """
    if overwrite or not os.path.isfile(filepath):
        return True
    # An existing file would be clobbered: let the user decide.
    return ask_to_proceed_with_overwrite(filepath)


def compile_args_from_training_config(training_config, custom_objects=None):
    """Rebuilds `model.compile` keyword arguments from a saved training config.

    Args:
      training_config: A dict holding a serialized training configuration.
        It must contain `optimizer_config`. `loss`, `metrics`,
        `weighted_metrics`, `sample_weight_mode` and `loss_weights` are
        optional and default to `None` when absent.
      custom_objects: Optional dict of custom objects made available while
        deserializing.

    Returns:
      A dict with the keys `optimizer`, `loss`, `metrics`,
      `weighted_metrics`, `loss_weights` and `sample_weight_mode`, suitable
      for `model.compile(**kwargs)`.

    Raises:
      KeyError: if `training_config` has no `optimizer_config` entry.
    """
    if custom_objects is None:
        custom_objects = {}

    with generic_utils.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get("weighted_metrics", None)
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        # `training_config` is a dict, so optional entries must be read as
        # keys. Attribute lookup (`hasattr`) never sees dict keys.
        sample_weight_mode = training_config.get("sample_weight_mode", None)
        loss_weights = training_config.get("loss_weights", None)

    return dict(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        weighted_metrics=weighted_metrics,
        loss_weights=loss_weights,
        sample_weight_mode=sample_weight_mode,
    )


def _deserialize_nested_config(deserialize_fn, config):
    """Deserializes arbitrary Keras `config` using `deserialize_fn`."""

    def _is_single_object(obj):
        if isinstance(obj, dict) and "class_name" in obj:
            return True  # Serialized Keras object.
        if isinstance(obj, str):
            return True  # Serialized function or string.
        return False

    if config is None:
        return None
    if _is_single_object(config):
        return deserialize_fn(config)
    elif isinstance(config, dict):
        return {
            k: _deserialize_nested_config(deserialize_fn, v)
            for k, v in config.items()
        }
    elif isinstance(config, (tuple, list)):
        return [
            _deserialize_nested_config(deserialize_fn, obj) for obj in config
        ]

    raise ValueError(
        "Saved configuration not understood. Configuration should be a "
        f"dictionary, string, tuple or list. Received: config={config}."
    )


def _serialize_nested_config(config):
    """Serializes the callable leaves of a nested structure of Keras objects.

    Args:
      config: A (possibly nested) structure, e.g. a dict of compile arguments.

    Returns:
      A structure of the same shape. Each callable leaf is replaced by
      `generic_utils.serialize_keras_object(leaf)`; every other leaf is kept
      unchanged.
    """

    def _maybe_serialize(leaf):
        if not callable(leaf):
            return leaf
        return generic_utils.serialize_keras_object(leaf)

    return tf.nest.map_structure(_maybe_serialize, config)


def _deserialize_metric(metric_config):
    """Deserializes one metric config, passing special shorthands through.

    Args:
      metric_config: A serialized metric (string name or config dict).

    Returns:
      `metric_config` itself for the shorthands "accuracy", "acc",
      "crossentropy" and "ce"; otherwise the result of
      `keras.metrics.deserialize(metric_config)`.
    """
    from keras import metrics as metrics_module

    # `compile` resolves these shorthands itself, choosing the concrete
    # accuracy / cross-entropy variant from the model's output shape.
    passthrough_names = ("accuracy", "acc", "crossentropy", "ce")
    if metric_config not in passthrough_names:
        return metrics_module.deserialize(metric_config)
    return metric_config


def _enforce_names_consistency(specs):
    """Ensures that either every spec in `specs` is named or none is.

    A leaf counts as named when it is `None` or has a non-`None` `name`
    attribute. If the leaves are mixed, every leaf is replaced by a deep copy
    whose private `_name` is reset to `None` (when the copy has a `name`
    attribute).

    Args:
      specs: A (possibly nested) structure of specs.

    Returns:
      `specs` unchanged when naming is already consistent, otherwise the
      name-cleared copy of the structure.
    """

    def _is_named(spec):
        if spec is None:
            return True
        return hasattr(spec, "name") and spec.name is not None

    def _without_name(spec):
        unnamed = copy.deepcopy(spec)
        if hasattr(unnamed, "name"):
            unnamed._name = None
        return unnamed

    named_flags = [_is_named(leaf) for leaf in tf.nest.flatten(specs)]
    if not any(named_flags) or all(named_flags):
        return specs
    return tf.nest.map_structure(_without_name, specs)


def try_build_compiled_arguments(model):
    """Best-effort build of a loaded model's compiled loss and metrics.

    Does nothing for v1 layers/models or when `model.outputs` is `None`.
    Otherwise it builds `model.compiled_loss` and `model.compiled_metrics`
    against `model.outputs` if they are not built yet. Any failure is logged
    as a warning instead of raised, because building can legitimately fail
    before the model has seen real data.

    Args:
      model: A compiled Keras model, typically one that was just loaded.
    """
    if version_utils.is_v1_layer_or_model(model) or model.outputs is None:
        return

    try:
        if not model.compiled_loss.built:
            model.compiled_loss.build(model.outputs)
        if not model.compiled_metrics.built:
            # `model.outputs` stands in for both targets and predictions.
            model.compiled_metrics.build(model.outputs, model.outputs)
    except Exception:  # Best effort; never swallow KeyboardInterrupt etc.
        logging.warning(
            "Compiled the loaded model, but the compiled metrics have "
            "yet to be built. `model.compiled_metrics` will be empty "
            "until you train or evaluate the model."
        )


def is_hdf5_filepath(filepath):
    """Returns whether `filepath` names an H5-format Keras save file.

    Args:
      filepath: A string path or an `os.PathLike` (e.g. `pathlib.Path`).

    Returns:
      `True` if the path ends in `.h5`, `.keras` or `.hdf5`, else `False`.
    """
    # Accept `pathlib.Path` and other path-like objects, not just `str`.
    if isinstance(filepath, os.PathLike):
        filepath = os.fspath(filepath)
    return filepath.endswith((".h5", ".keras", ".hdf5"))
