# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction."""

# This file was originally under tf/python/feature_column, and was moved to
# Keras package in order to remove the reverse dependency from TF to Keras.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re

import tensorflow.compat.v2 as tf

from keras.engine.base_layer import Layer
from keras.utils import generic_utils


class _BaseFeaturesLayer(Layer):
    """Base class for DenseFeatures and SequenceFeatures.

    Defines common methods and helpers. Subclasses must implement
    `_target_shape`, which `compute_output_shape` and
    `_process_dense_tensor` use to compute output shapes.

    Args:
      feature_columns: An iterable containing the FeatureColumns to use as
        inputs to your model.
      expected_column_type: Expected class for provided feature columns.
      trainable:  Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the layer.
      partitioner: Optional partitioner passed to
        `tf.compat.v1.variable_scope` when column state is created in
        `build`.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: if an item in `feature_columns` doesn't match
        `expected_column_type`, or if `feature_columns` is empty, a dict,
        contains non-FeatureColumn items or duplicate column names.
    """

    def __init__(
        self,
        feature_columns,
        expected_column_type,
        trainable,
        name,
        partitioner=None,
        **kwargs
    ):
        super().__init__(name=name, trainable=trainable, **kwargs)
        # Normalized to a list sorted by column name (see
        # `_normalize_feature_columns`).
        self._feature_columns = _normalize_feature_columns(feature_columns)
        self._state_manager = tf.__internal__.feature_column.StateManager(
            self, self.trainable
        )
        self._partitioner = partitioner
        for column in self._feature_columns:
            if not isinstance(column, expected_column_type):
                raise ValueError(
                    "Items of feature_columns must be a {}. "
                    "You can wrap a categorical column with an "
                    "embedding_column or indicator_column. Given: {}".format(
                        expected_column_type, column
                    )
                )

    def build(self, _):
        """Creates the state of every feature column under this layer's scope.

        Each column's state is created inside a variable scope named after
        the layer (with the configured partitioner) and, nested within it, a
        scope named after the sanitized column name.
        """
        for column in self._feature_columns:
            # NOTE: the outer scope is re-entered per column, as in the
            # original implementation; hoisting it out of the loop could
            # change generated op/scope names.
            with tf.compat.v1.variable_scope(
                self.name, partitioner=self._partitioner
            ):
                with tf.compat.v1.variable_scope(
                    _sanitize_column_name_for_variable_scope(column.name)
                ):
                    column.create_state(self._state_manager)
        super().build(None)

    def _target_shape(self, input_shape, num_elements):
        """Computes expected output shape of the layer or a column's dense tensor.

        Abstract hook: `compute_output_shape` and `_process_dense_tensor`
        call this method, so subclasses must override it.

        Args:
          input_shape: Tensor or array with batch shape.
          num_elements: Size of the last dimension of the output.

        Returns:
          Tuple with output shape.

        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Calling an abstract method.")

    def _output_shape(self, input_shape, num_elements):
        """Backward-compatible alias that delegates to `_target_shape`."""
        return self._target_shape(input_shape, num_elements)

    def compute_output_shape(self, input_shape):
        """Returns the layer's output shape.

        The last dimension is the sum of `variable_shape.num_elements()` over
        all feature columns; the full shape is produced by `_target_shape`.
        """
        total_elements = sum(
            column.variable_shape.num_elements()
            for column in self._feature_columns
        )
        return self._target_shape(input_shape, total_elements)

    def _process_dense_tensor(self, column, tensor):
        """Reshapes the dense tensor output of a column based on expected shape.

        Args:
          column: A DenseColumn or SequenceDenseColumn object.
          tensor: A dense tensor obtained from the same column.

        Returns:
          Reshaped dense tensor.
        """
        num_elements = column.variable_shape.num_elements()
        target_shape = self._target_shape(tf.shape(tensor), num_elements)
        return tf.reshape(tensor, shape=target_shape)

    def _verify_and_concat_tensors(self, output_tensors):
        """Verifies and concatenates the dense output of several columns.

        Static batch sizes must agree across `output_tensors`; the tensors are
        then concatenated along the last axis.
        """
        _verify_static_batch_size_equality(
            output_tensors, self._feature_columns
        )
        return tf.concat(output_tensors, -1)

    def get_config(self):
        """Returns the layer config, including serialized feature columns."""
        column_configs = [
            tf.__internal__.feature_column.serialize_feature_column(fc)
            for fc in self._feature_columns
        ]
        config = {
            "feature_columns": column_configs,
            "partitioner": generic_utils.serialize_keras_object(
                self._partitioner
            ),
        }

        base_config = super().get_config()
        # Entries from this layer's config take precedence over the base's.
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuilds the layer from a config produced by `get_config`.

        Columns are deserialized with a shared `columns_by_name` cache so
        that columns referenced by several others are only built once. A
        missing "partitioner" entry is treated as no partitioner.
        """
        config_cp = config.copy()
        columns_by_name = {}
        config_cp["feature_columns"] = [
            tf.__internal__.feature_column.deserialize_feature_column(
                c, custom_objects, columns_by_name
            )
            for c in config["feature_columns"]
        ]
        # `.get` keeps configs without a "partitioner" key loadable;
        # deserializing None yields None, matching the constructor default.
        config_cp["partitioner"] = generic_utils.deserialize_keras_object(
            config.get("partitioner"), custom_objects
        )

        return cls(**config_cp)


def _sanitize_column_name_for_variable_scope(name):
    """Sanitizes user-provided feature names for use as variable scopes."""
    invalid_char = re.compile("[^A-Za-z0-9_.\\-]")
    return invalid_char.sub("_", name)


def _verify_static_batch_size_equality(tensors, columns):
    """Verify equality between static batch sizes.

    Args:
      tensors: iterable of input tensors.
      columns: Corresponding feature columns.

    Raises:
      ValueError: in case of mismatched batch sizes.
    """
    expected_batch_size = None
    for i in range(0, len(tensors)):
        # bath_size is a Dimension object.
        batch_size = tf.compat.v1.Dimension(
            tf.compat.dimension_value(tensors[i].shape[0])
        )
        if batch_size.value is not None:
            if expected_batch_size is None:
                bath_size_column_index = i
                expected_batch_size = batch_size
            elif not expected_batch_size.is_compatible_with(batch_size):
                raise ValueError(
                    "Batch size (first dimension) of each feature must be "
                    "same. Batch size of columns ({}, {}): ({}, {})".format(
                        columns[bath_size_column_index].name,
                        columns[i].name,
                        expected_batch_size,
                        batch_size,
                    )
                )


def _normalize_feature_columns(feature_columns):
    """Validates `feature_columns` and returns them as a name-sorted list.

    A single FeatureColumn is wrapped in a list and an iterator is
    materialized into a list. Then three checks run in order:
      1. every item is a FeatureColumn;
      2. the collection is non-empty;
      3. no two columns share a `name`.

    Args:
      feature_columns: The raw feature columns, usually passed by users.

    Returns:
      A new list with the columns sorted by `name`.

    Raises:
      ValueError: if `feature_columns` is a dict, contains a non-FeatureColumn
        item, is empty, or contains duplicate column names.
    """
    column_type = tf.__internal__.feature_column.FeatureColumn

    if isinstance(feature_columns, column_type):
        feature_columns = [feature_columns]
    # Materialize one-shot iterators, because the input is traversed
    # several times below.
    if isinstance(feature_columns, collections.abc.Iterator):
        feature_columns = list(feature_columns)
    if isinstance(feature_columns, dict):
        raise ValueError("Expected feature_columns to be iterable, found dict.")

    # Type-check every item before the emptiness and duplicate checks, so the
    # error reported for a given input stays the same.
    for candidate in feature_columns:
        if isinstance(candidate, column_type):
            continue
        raise ValueError(
            "Items of feature_columns must be a FeatureColumn. "
            "Given (type {}): {}.".format(type(candidate), candidate)
        )

    if not feature_columns:
        raise ValueError("feature_columns must not be empty.")

    seen_by_name = {}
    for candidate in feature_columns:
        key = candidate.name
        if key in seen_by_name:
            raise ValueError(
                "Duplicate feature column name found for columns: {} "
                "and {}. This usually means that these columns refer to "
                "same base feature. Either one must be discarded or a "
                "duplicated but renamed item must be inserted in "
                "features dict.".format(candidate, seen_by_name[key])
            )
        seen_by_name[key] = candidate

    return sorted(feature_columns, key=lambda fc: fc.name)
