# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for saving/loading function for keras Model."""

import collections

import keras

# Record returned by every model builder in this file: the model together
# with the input shape(s) and target shape(s) that go with it.
ModelFn = collections.namedtuple(
    typename="ModelFn",
    field_names=("model", "input_shape", "target_shape"),
)


def basic_sequential():
    """Build a two-layer `Sequential` model with a declared input shape.

    Returns:
        A `ModelFn` holding the model, input shape `(None, 3)` and target
        shape `(None, 2)`.
    """
    hidden = keras.layers.Dense(units=3, activation="relu", input_shape=(3,))
    head = keras.layers.Dense(units=2, activation="softmax")
    return ModelFn(
        model=keras.Sequential([hidden, head]),
        input_shape=(None, 3),
        target_shape=(None, 2),
    )


def basic_sequential_deferred():
    """Build a two-layer `Sequential` model with no input shape declared.

    Unlike `basic_sequential`, no layer is given an `input_shape`.

    Returns:
        A `ModelFn` holding the model, input shape `(None, 3)` and target
        shape `(None, 2)`.
    """
    stack = [
        keras.layers.Dense(units=3, activation="relu"),
        keras.layers.Dense(units=2, activation="softmax"),
    ]
    deferred = keras.Sequential(stack)
    return ModelFn(
        model=deferred, input_shape=(None, 3), target_shape=(None, 2)
    )


def stacked_rnn():
    """Build a functional model: an RNN over three LSTM cells, then Dense.

    Returns:
        A `ModelFn` holding the model, input shape `(None, 4, 3)` and target
        shape `(None, 2)`.
    """
    sequence_input = keras.Input(shape=(None, 3))

    # Three identical cells wrapped by a single RNN layer.
    cells = []
    for _ in range(3):
        cells.append(keras.layers.LSTMCell(units=2))
    recurrent = keras.layers.RNN(cells)

    features = recurrent(sequence_input)
    predictions = keras.layers.Dense(units=2)(features)
    return ModelFn(
        keras.Model(inputs=sequence_input, outputs=predictions),
        (None, 4, 3),
        (None, 2),
    )


def lstm():
    """Build a functional model of three stacked LSTM layers and a Dense head.

    Returns:
        A `ModelFn` holding the model, input shape `(None, 4, 3)` and target
        shape `(None, 2)`.
    """
    sequence_input = keras.Input(shape=(None, 3))

    hidden = sequence_input
    # The first two LSTMs return full sequences so another LSTM can follow.
    for units in (4, 3):
        hidden = keras.layers.LSTM(units, return_sequences=True)(hidden)
    # The last LSTM does not return sequences.
    hidden = keras.layers.LSTM(2, return_sequences=False)(hidden)

    predictions = keras.layers.Dense(units=2)(hidden)
    network = keras.Model(inputs=sequence_input, outputs=predictions)
    return ModelFn(network, (None, 4, 3), (None, 2))


def multi_input_multi_output():
    """Build a functional model with two named inputs and two named outputs.

    The "body" input goes through an Embedding and an LSTM; the result is
    concatenated with the "tags" input and fed to two Dense heads named
    "priority" and "department".

    Returns:
        A `ModelFn` with input shapes `[(None, 1), (None, 2)]` and target
        shapes `[(None, 2), (None, 3)]`.
    """
    body = keras.Input(shape=(None,), name="body")
    tags = keras.Input(shape=(2,), name="tags")

    embedded = keras.layers.Embedding(input_dim=10, output_dim=4)(body)
    encoded_body = keras.layers.LSTM(units=5)(embedded)
    merged = keras.layers.concatenate([encoded_body, tags])

    priority = keras.layers.Dense(
        units=2, activation="sigmoid", name="priority"
    )(merged)
    department = keras.layers.Dense(
        units=3, activation="softmax", name="department"
    )(merged)

    network = keras.Model(inputs=[body, tags], outputs=[priority, department])
    return ModelFn(
        model=network,
        input_shape=[(None, 1), (None, 2)],
        target_shape=[(None, 2), (None, 3)],
    )


def nested_sequential_in_functional():
    """Build a functional model that calls a `Sequential` model as a layer.

    Returns:
        A `ModelFn` holding the outer model, input shape `(None, 3)` and
        target shape `(None, 2)`.
    """
    inner_layers = [
        keras.layers.Dense(units=3, activation="relu", input_shape=(3,)),
        keras.layers.Dense(units=2, activation="relu"),
    ]
    nested = keras.Sequential(inner_layers)

    outer_input = keras.Input(shape=(3,))
    nested_output = nested(outer_input)
    predictions = keras.layers.Dense(units=2, activation="softmax")(
        nested_output
    )
    return ModelFn(
        keras.Model(inputs=outer_input, outputs=predictions),
        (None, 3),
        (None, 2),
    )


def seq_to_seq():
    """Build an LSTM encoder-decoder (sequence to sequence) model.

    The two states returned by the encoder LSTM are used as the initial
    state of the decoder LSTM, whose output sequence goes through a softmax
    Dense layer.

    Returns:
        A `ModelFn` with input shapes `[(None, 2, 3), (None, 2, 3)]` and
        target shape `(None, 2, 3)`.
    """
    latent_dim = 2
    num_encoder_tokens = 3
    num_decoder_tokens = 3

    # Encoder: discard the output, keep the two returned states.
    encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
    encoder_lstm = keras.layers.LSTM(units=latent_dim, return_state=True)
    _, hidden_state, cell_state = encoder_lstm(encoder_inputs)

    # Decoder: seeded with the encoder states; only its sequence is used.
    decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
    decoder_lstm = keras.layers.LSTM(
        units=latent_dim, return_sequences=True, return_state=True
    )
    decoded_sequence = decoder_lstm(
        decoder_inputs, initial_state=[hidden_state, cell_state]
    )[0]
    token_probs = keras.layers.Dense(
        units=num_decoder_tokens, activation="softmax"
    )(decoded_sequence)

    encoder_shape = (None, 2, num_encoder_tokens)
    decoder_shape = (None, 2, num_decoder_tokens)
    return ModelFn(
        keras.Model([encoder_inputs, decoder_inputs], token_probs),
        [encoder_shape, decoder_shape],
        decoder_shape,
    )


def shared_layer_functional():
    """Build a two-input, two-output functional model around one LSTM.

    The LSTM output feeds both the "aux_output" head and, after being
    concatenated with "aux_input", the branch ending in "main_output".

    Returns:
        A `ModelFn` with input shapes `[(None, 10), (None, 5)]` and target
        shapes `[(None, 1), (None, 1)]`.
    """
    token_ids = keras.Input(shape=(10,), dtype="int32", name="main_input")
    embedded = keras.layers.Embedding(
        input_dim=4, output_dim=5, input_length=10
    )(token_ids)
    sequence_summary = keras.layers.LSTM(units=3)(embedded)

    # Auxiliary head reads the LSTM output directly.
    aux_head = keras.layers.Dense(
        units=1, activation="sigmoid", name="aux_output"
    )
    aux_prediction = aux_head(sequence_summary)

    # Main branch mixes the LSTM output with the second input.
    extra_features = keras.Input(shape=(5,), name="aux_input")
    merged = keras.layers.concatenate([sequence_summary, extra_features])
    hidden = keras.layers.Dense(units=2, activation="relu")(merged)
    main_head = keras.layers.Dense(
        units=1, activation="sigmoid", name="main_output"
    )
    main_prediction = main_head(hidden)

    network = keras.Model(
        inputs=[token_ids, extra_features],
        outputs=[main_prediction, aux_prediction],
    )
    return ModelFn(
        model=network,
        input_shape=[(None, 10), (None, 5)],
        target_shape=[(None, 1), (None, 1)],
    )


def shared_sequential():
    """Build a functional model that applies one `Sequential` to two inputs.

    The same two-layer Conv2D stack is called on each input; the two results
    are concatenated and globally average-pooled.

    Returns:
        A `ModelFn` with input shapes `[(None, 5, 5, 3), (None, 5, 5, 3)]`
        and target shape `(None, 4)`.
    """
    conv_stack = keras.Sequential(
        [
            keras.layers.Conv2D(filters=2, kernel_size=3, activation="relu")
            for _ in range(2)
        ]
    )

    image_shape = (5, 5, 3)
    image_inputs = [keras.Input(shape=image_shape) for _ in range(2)]
    # One shared stack, called once per input.
    branches = [conv_stack(image) for image in image_inputs]

    merged = keras.layers.concatenate(branches)
    pooled = keras.layers.GlobalAveragePooling2D()(merged)
    network = keras.Model(image_inputs, pooled)
    return ModelFn(
        network, [(None,) + image_shape, (None,) + image_shape], (None, 4)
    )


class MySubclassModel(keras.Model):
    """Subclassed model: Dense -> Dropout -> BatchNormalization -> Dense.

    `input_dim` is only recorded in the config returned by `get_config`; no
    layer is constructed from it.
    """

    def __init__(self, input_dim=3):
        super().__init__(name="my_subclass_model")
        # Stored so `from_config` can call the constructor with the same
        # argument.
        self._config = dict(input_dim=input_dim)
        self.dense1 = keras.layers.Dense(units=8, activation="relu")
        self.dense2 = keras.layers.Dense(units=2, activation="softmax")
        self.bn = keras.layers.BatchNormalization()
        self.dp = keras.layers.Dropout(rate=0.5)

    def call(self, inputs, **kwargs):
        """Run `inputs` through dense1, dropout, batch norm, then dense2."""
        dropped = self.dp(self.dense1(inputs))
        return self.dense2(self.bn(dropped))

    def get_config(self):
        """Return the stored constructor-argument dict."""
        return self._config

    @classmethod
    def from_config(cls, config):
        """Create an instance by passing `config` as keyword arguments."""
        return cls(**config)


def nested_subclassed_model():
    """Build a subclassed model that contains a `MySubclassModel`.

    Returns:
        A `ModelFn` holding the outer model, input shape `(None, 3)` and
        target shape `(None, 2)`.
    """

    class NestedSubclassModel(keras.Model):
        """Dense -> BatchNormalization -> `MySubclassModel` -> Dense."""

        def __init__(self):
            super().__init__()
            self.dense1 = keras.layers.Dense(units=4, activation="relu")
            self.dense2 = keras.layers.Dense(units=2, activation="relu")
            self.bn = keras.layers.BatchNormalization()
            self.inner_subclass_model = MySubclassModel()

        def call(self, inputs):
            normalized = self.bn(self.dense1(inputs))
            return self.dense2(self.inner_subclass_model(normalized))

    outer = NestedSubclassModel()
    return ModelFn(model=outer, input_shape=(None, 3), target_shape=(None, 2))


def nested_subclassed_in_functional_model():
    """Build a functional model that calls a `MySubclassModel` as a layer.

    Returns:
        A `ModelFn` holding the functional model, input shape `(None, 3)`
        and target shape `(None, 2)`.
    """
    nested = MySubclassModel()
    outer_input = keras.Input(shape=(3,))

    hidden = nested(outer_input)
    hidden = keras.layers.BatchNormalization()(hidden)
    predictions = keras.layers.Dense(units=2, activation="softmax")(hidden)

    return ModelFn(
        keras.Model(inputs=outer_input, outputs=predictions),
        (None, 3),
        (None, 2),
    )


def nested_functional_in_subclassed_model():
    """Build a subclassed model that contains a functional model.

    Returns:
        A `ModelFn` holding the subclassed model, input shape `(None, 3)`
        and target shape `(None, 2)`.
    """

    def build_inner_functional():
        """Dense -> BatchNormalization -> Dense over a 4-feature input."""
        inner_input = keras.Input(shape=(4,))
        hidden = keras.layers.Dense(units=4, activation="relu")(inner_input)
        hidden = keras.layers.BatchNormalization()(hidden)
        inner_output = keras.layers.Dense(units=2)(hidden)
        return keras.Model(inputs=inner_input, outputs=inner_output)

    class NestedFunctionalInSubclassModel(keras.Model):
        """Dense -> inner functional model -> Dense."""

        def __init__(self):
            super().__init__(name="nested_functional_in_subclassed_model")
            self.dense1 = keras.layers.Dense(units=4, activation="relu")
            self.dense2 = keras.layers.Dense(units=2, activation="relu")
            self.inner_functional_model = build_inner_functional()

        def call(self, inputs):
            projected = self.dense1(inputs)
            return self.dense2(self.inner_functional_model(projected))

    outer = NestedFunctionalInSubclassModel()
    return ModelFn(model=outer, input_shape=(None, 3), target_shape=(None, 2))


def shared_layer_subclassed_model():
    """Build a subclassed model that calls the same Dense layer twice.

    Returns:
        A `ModelFn` holding the model, input shape `(None, 3)` and target
        shape `(None, 3)`.
    """

    class SharedLayerSubclassModel(keras.Model):
        """Dense -> Dropout -> BatchNormalization -> the same Dense again."""

        def __init__(self):
            super().__init__(name="shared_layer_subclass_model")
            self.dense = keras.layers.Dense(units=3, activation="relu")
            self.dp = keras.layers.Dropout(rate=0.5)
            self.bn = keras.layers.BatchNormalization()

        def call(self, inputs):
            first_pass = self.dense(inputs)
            normalized = self.bn(self.dp(first_pass))
            # Second use of `self.dense`: this is the shared layer.
            return self.dense(normalized)

    shared = SharedLayerSubclassModel()
    return ModelFn(model=shared, input_shape=(None, 3), target_shape=(None, 3))


def functional_with_keyword_args():
    """Build a functional model whose constructor receives keyword args.

    The model is created with `name="m"` and `trainable=False` in addition
    to its inputs and outputs.

    Returns:
        A `ModelFn` holding the model, input shape `(None, 3)` and target
        shape `(None, 2)`.
    """
    net_input = keras.Input(shape=(3,))
    hidden = keras.layers.Dense(units=4)(net_input)
    hidden = keras.layers.BatchNormalization()(hidden)
    net_output = keras.layers.Dense(units=2)(hidden)

    extra_kwargs = {"name": "m", "trainable": False}
    network = keras.Model(net_input, net_output, **extra_kwargs)
    return ModelFn(
        model=network, input_shape=(None, 3), target_shape=(None, 2)
    )


# `(name, builder)` pairs for the model builders defined in this file. Each
# name is the builder function's own `__name__`.
ALL_MODELS = [
    (builder.__name__, builder)
    for builder in (
        basic_sequential,
        basic_sequential_deferred,
        stacked_rnn,
        lstm,
        multi_input_multi_output,
        nested_sequential_in_functional,
        seq_to_seq,
        shared_layer_functional,
        shared_sequential,
        nested_subclassed_model,
        nested_subclassed_in_functional_model,
        nested_functional_in_subclassed_model,
        shared_layer_subclassed_model,
        functional_with_keyword_args,
    )
]


def get_models(exclude_models=None):
    """Return the `ALL_MODELS` entries whose names are not excluded.

    Args:
        exclude_models: Optional container of model names to leave out.
            `None` (the default) excludes nothing.

    Returns:
        A new list of `(name, builder)` tuples taken from `ALL_MODELS`, in
        their original order.
    """
    # `name not in None` raises TypeError, so map the default to an empty
    # container before testing membership.
    excluded = () if exclude_models is None else exclude_models
    return [entry for entry in ALL_MODELS if entry[0] not in excluded]
