# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base Class for TPU Embeddings Mid level APIs."""

import functools
from typing import Any, Dict, Iterable, Optional, Union, Text

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest


class TPUEmbeddingBase(autotrackable.AutoTrackable):
  """The TPUEmbedding Base class.

  This class only contains the basic logic to check the feature config and table
  config for the tpu embedding mid level APIs.
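
  A minimal usage sketch, assuming the standard config classes under
  `tf.tpu.experimental.embedding` (`MyTPUEmbedding` is a hypothetical
  concrete subclass of this base class):

  ```python
  table = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=1024, dim=16, name="video")
  feature_config = {
      "watched": tf.tpu.experimental.embedding.FeatureConfig(table=table)}
  mid_level = MyTPUEmbedding(
      feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(learning_rate=0.05))
  ```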
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer] = None):  # pylint:disable=protected-access
    """Creates the TPUEmbeddingBase object."""
    self._feature_config = feature_config
    self._output_shapes = []
    # Set the table order here to the order of the first occurrence of each
    # table in a feature provided by the user. The order of this struct must
    # be fixed to provide the user with deterministic behavior over multiple
    # instantiations.
    self._table_config = []
    for feature in nest.flatten(feature_config):
      self._output_shapes.append(feature.output_shape)
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error-check the optimizer here, as
    # we deliberately don't do that in the TableConfig class: this allows
    # high-level APIs built on top of this one to use strings or other classes
    # to represent optimizers (before they are passed to this class).
    table_names = []
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if (table.optimizer is not None and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.append(table.name)

    self._built = False

  @property
  def embedding_tables(self):
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    raise NotImplementedError

  def _create_variables(self, table: tpu_embedding_v2_utils.TableConfig,
                        trainable: bool) -> Dict[Text, tf_variables.Variable]:
    """Create all variables including table variables and slot variables."""
    variable_shape = (table.vocabulary_size, table.dim)

    def getter(name, shape, dtype, initializer, trainable):
      del shape
      # _add_variable_with_custom_getter clears the shape sometimes, so we
      # take the global shape from outside the getter.
      initial_value = functools.partial(
          initializer, variable_shape, dtype=dtype)
      return tf_variables.Variable(
          name=name,
          initial_value=initial_value,
          shape=variable_shape,
          dtype=dtype,
          trainable=trainable)

    def variable_creator(name, initializer, trainable=True):
      # Use _add_variable_with_custom_getter here so that we hook into
      # checkpoint loading: values can be restored before the variables are
      # created, which avoids double initialization.
      return self._add_variable_with_custom_getter(
          name=name,
          initializer=initializer,
          shape=variable_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)

    parameters = variable_creator(
        table.name, table.initializer, trainable=trainable)

    def slot_creator(name, initializer):
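      # Slot variables are namespaced under their table, e.g. a slot named
      # "accumulators" for table "video" becomes variable "video/accumulators".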
      return variable_creator(table.name + "/" + name, initializer, False)

    if table.optimizer is not None:
      slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
    else:
      slot_vars = {}
    slot_vars["parameters"] = parameters
    return slot_vars

  def _create_variables_and_slots(self):
    """Create variables and slots variables for TPU embeddings."""
    raise NotImplementedError

  def build(self):
    """Create variables and slots variables for TPU embeddings."""
    if self._built:
      return
    self._variables = self._create_variables_and_slots()
    self._built = True

  def __call__(self, features: Any, weights: Optional[Any] = None) -> Any:
    """Call the mid level api to do embedding lookup."""
    if not self._built:
      self.build()
    return self.embedding_lookup(features, weights)

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Lookup the embedding table using the input features."""
    raise NotImplementedError
