# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU embedding APIs."""

import collections
import copy
import math
import re

from typing import Optional

import six

from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util.tf_export import tf_export

TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE


# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
#  as AdagradParameters etc instead of learning_rate.
class TableConfig(
    collections.namedtuple('TableConfig', [
        'vocabulary_size',
        'dimension',
        'initializer',
        'combiner',
        'hot_id_replication',
        'learning_rate',
        'learning_rate_fn',
        'optimization_parameters',
    ])):
  """Embedding table configuration."""

  def __new__(cls,
              vocabulary_size,
              dimension,
              initializer=None,
              combiner='mean',
              hot_id_replication=False,
              learning_rate=None,
              learning_rate_fn=None,
              optimization_parameters=None):
    """Embedding table configuration.

    Args:
      vocabulary_size: Number of vocabulary entries (rows) in the table.
      dimension: The embedding dimension.
      initializer: A variable initializer function to be used in embedding
        variable initialization. If not specified, defaults to
        `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard
        deviation `1/sqrt(dimension)`.
      combiner: A string specifying how to reduce if there are multiple entries
        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
        supported, with 'mean' the default. 'sqrtn' often achieves good
        accuracy, in particular with bag-of-words columns. For more information,
        see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather
        than sparse tensors.
      hot_id_replication: If true, enables hot id replication, which can make
        embedding lookups faster if there are some hot rows in the table.
      learning_rate: float, static learning rate for this table. If
        learning_rate and learning_rate_fn are both `None`, static learning rate
        as specified in local `optimization_parameters` will be used. In case
        local `optimization_parameters` is `None`, global
        `optimization_parameters` in `TPUEmbedding` constructor will be used.
        `learning_rate_fn` must be `None` if `learning_rate` is not `None`.
      learning_rate_fn: string, use dynamic learning rate given by the function.
        This function will be passed the current global step. If learning_rate
        and learning_rate_fn are both `None`, static learning rate as specified
        in `optimization_parameters` is used. `learning_rate` must be `None` if
        `learning_rate_fn` is not `None`.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Specifies the table-level
        optimizer. If it is `None`, the global optimizer in the `TPUEmbedding`
        constructor is used.

    Returns:
      `TableConfig`.

    Raises:
      ValueError: if `vocabulary_size` is not a positive integer.
      ValueError: if `dimension` is not a positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
      ValueError: if `learning_rate` and `learning_rate_fn` are both not
        `None`.
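
    For illustration, a minimal table configuration might look like the
    following (the vocabulary size, dimension, and optimizer settings below
    are arbitrary placeholders rather than recommendations):

    ```
    video_table = tpu_embedding.TableConfig(
        vocabulary_size=100000,
        dimension=64,
        combiner='mean',
        optimization_parameters=tpu_embedding.AdagradParameters(
            learning_rate=0.05, initial_accumulator=0.1))
    ```
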
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
      raise ValueError(f'vocabulary_size must be >= 1. '
                       f'Received: {vocabulary_size}.')

    if not isinstance(dimension, int) or dimension < 1:
      raise ValueError(
          f'dimension must be a positive int. Received: {dimension}.')

    if (initializer is not None) and (not callable(initializer)):
      raise ValueError(f'initializer must be callable if specified. '
                       f'Received: {initializer}.')
    if initializer is None:
      initializer = init_ops.truncated_normal_initializer(
          mean=0.0, stddev=1 / math.sqrt(dimension))

    if combiner not in ('mean', 'sum', 'sqrtn', None):
      raise ValueError(f'combiner must be "mean", "sum", "sqrtn" or None. '
                       f'Received: {combiner}.')

    if learning_rate is not None and learning_rate_fn is not None:
      raise ValueError('At most one of learning_rate and learning_rate_fn '
                       'can be set. Received: {} and {}'.format(
                           learning_rate, learning_rate_fn))

    if optimization_parameters is not None:
      if not isinstance(optimization_parameters, _OptimizationParameters):
        raise ValueError(f'`optimization_parameters` must inherit from '
                         f'`_OptimizationParameters`. '
                         f'Received: `type(optimization_parameters)`='
                         f'{type(optimization_parameters)}.')

    return super(TableConfig,
                 cls).__new__(cls, vocabulary_size, dimension, initializer,
                              combiner, hot_id_replication, learning_rate,
                              learning_rate_fn, optimization_parameters)


class FeatureConfig(
    collections.namedtuple('FeatureConfig',
                           ['table_id', 'max_sequence_length', 'weight_key'])):
  """Feature configuration."""

  def __new__(cls, table_id, max_sequence_length=0, weight_key=None):
    """Feature configuration.

    Args:
      table_id: Which table the feature uses for embedding lookups.
      max_sequence_length: If positive, the feature is a sequence feature with
        the corresponding maximum sequence length. If the sequence is longer
        than this, it will be truncated. If 0, the feature is not a sequence
        feature.
      weight_key: If using weights for the combiner, this key specifies which
        input feature contains the weights.

    Returns:
      `FeatureConfig`.

    Raises:
      ValueError: if `max_sequence_length` is not an integer or is negative.
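
    For example, a sequence feature that looks up ids in a table named
    'video' and keeps at most 30 ids per example could be declared as follows
    (the table name and sequence length are illustrative placeholders):

    ```
    watched_config = tpu_embedding.FeatureConfig(
        table_id='video', max_sequence_length=30)
    ```
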
    """
    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
      raise ValueError(f'max_sequence_length must be zero or a positive int, '
                       f'got {max_sequence_length}.')

    return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length,
                                             weight_key)


class EnqueueData(
    collections.namedtuple(
        'EnqueueData',
        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
  """Data to be enqueued through generate_enqueue_ops()."""

  def __new__(cls,
              embedding_indices,
              sample_indices=None,
              aggregation_weights=None):
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
        and int64 are allowed and will be converted to int32 internally.
      sample_indices: A rank 2 Tensor specifying the training example to which
        the corresponding embedding_indices and aggregation_weights values
        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
        If it is None, we assume each embedding_indices belongs to a different
        sample. Both int32 and int64 are allowed and will be converted to int32
        internally.
      aggregation_weights: A rank 1 Tensor containing aggregation weights. It
        corresponds to sp_weights.values in embedding_lookup_sparse(). If it is
        None, we assume all weights are 1. Both float32 and float64 are allowed
        and will be converted to float32 internally.

    Returns:
      An EnqueueData tuple.
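
    For example, given a `tf.sparse.SparseTensor` of ids, the equivalent
    EnqueueData can be built with `from_sparse_tensor` (a sketch; the indices
    and values below are placeholders):

    ```
    sp_ids = tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=[7, 3, 9],
        dense_shape=[2, 2])
    enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(sp_ids)
    ```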

    """
    return super(EnqueueData, cls).__new__(cls, embedding_indices,
                                           sample_indices, aggregation_weights)

  @staticmethod
  def from_sparse_tensor(sp_tensor, weights=None):
    return EnqueueData(
        sp_tensor.values,
        sp_tensor.indices,
        aggregation_weights=weights.values if weights is not None else None)


class RaggedEnqueueData(
    collections.namedtuple(
        'RaggedEnqueueData',
        ['embedding_indices', 'row_splits', 'aggregation_weights'])):
  """RaggedTensor Data to be enqueued through generate_enqueue_ops()."""

  def __new__(cls,
              embedding_indices,
              row_splits=None,
              aggregation_weights=None):
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to ids.values in embedding_lookup(), when ids is a
        RaggedTensor. Both int32 and int64 are allowed and will be converted to
        int32 internally.
      row_splits: A rank 1 Tensor specifying the break points for splitting
        embedding_indices and aggregation_weights. It corresponds to
        ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
        int32 and int64 are allowed and will be converted to int32 internally.
      aggregation_weights: A rank 1 Tensor containing per training example
        aggregation weights. It corresponds to the values field of a
        RaggedTensor with the same row_splits as ids in embedding_lookup(), when
        ids is a RaggedTensor.

    Returns:
      A RaggedEnqueueData tuple.
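
    For example, given a `tf.RaggedTensor` of ids, the equivalent
    RaggedEnqueueData can be built with `from_ragged_tensor` (a sketch; the id
    values below are placeholders):

    ```
    rg_ids = tf.ragged.constant([[7, 3], [9]], dtype=tf.int64)
    enqueue_data = tpu_embedding.RaggedEnqueueData.from_ragged_tensor(rg_ids)
    ```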

    """
    return super(RaggedEnqueueData,
                 cls).__new__(cls, embedding_indices, row_splits,
                              aggregation_weights)

  @staticmethod
  def from_ragged_tensor(rg_tensor, weights=None):
    return RaggedEnqueueData(
        rg_tensor.values,
        rg_tensor.row_splits,
        aggregation_weights=weights.values if weights is not None else None)


def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
  """Convenient function for generate_enqueue_ops().

  Args:
    sp_tensors_list: a list of dictionary mapping from string of feature names
      to SparseTensor. Each dictionary is for one TPU core. Dictionaries for the
      same host should be contiguous on the list.

  Returns:
    enqueue_datas_list: a list of dictionary mapping from string
      of feature names to EnqueueData. Each dictionary is for one
      TPU core. Dictionaries for the same host should be contiguous
      on the list.
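
  For example, for two TPU cores that each receive a single 'watched' feature
  (a sketch; `sp_watched_core0` and `sp_watched_core1` stand in for real
  SparseTensors):

  ```
  sp_tensors_list = [{'watched': sp_watched_core0},
                     {'watched': sp_watched_core1}]
  enqueue_datas_list = (
      tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
          sp_tensors_list))
  ```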

  """
  enqueue_datas_list = []
  for sp_tensors in sp_tensors_list:
    enqueue_datas = collections.OrderedDict(
        (k, EnqueueData.from_sparse_tensor(v))
        for k, v in six.iteritems(sp_tensors))
    enqueue_datas_list.append(enqueue_datas)
  return enqueue_datas_list


def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
  """Convenient function for generate_enqueue_ops().

  Args:
    rg_tensors_list: a list of dictionary mapping from string of feature names
      to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the
      same host should be contiguous on the list.

  Returns:
    enqueue_datas_list: a list of dictionary mapping from string
      of feature names to RaggedEnqueueData. Each dictionary is for one
      TPU core. Dictionaries for the same host should be contiguous
      on the list.
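
  For example, for two TPU cores that each receive a single 'watched' feature
  (a sketch; `rg_watched_core0` and `rg_watched_core1` stand in for real
  RaggedTensors):

  ```
  rg_tensors_list = [{'watched': rg_watched_core0},
                     {'watched': rg_watched_core1}]
  enqueue_datas_list = (
      tpu_embedding.get_enqueue_datas_list_from_ragged_tensors_list(
          rg_tensors_list))
  ```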

  """
  enqueue_datas_list = []
  for rg_tensors in rg_tensors_list:
    enqueue_datas = collections.OrderedDict(
        (k, RaggedEnqueueData.from_ragged_tensor(v))
        for k, v in six.iteritems(rg_tensors))
    enqueue_datas_list.append(enqueue_datas)
  return enqueue_datas_list


AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames',
                                               ['m', 'v'])

AdagradSlotVariableNames = collections.namedtuple('AdagradSlotVariableNames',
                                                  ['accumulator'])

MomentumSlotVariableNames = collections.namedtuple('MomentumSlotVariableNames',
                                                   ['momenta'])

AdagradMomentumSlotVariableNames = collections.namedtuple(
    'AdagradMomentumSlotVariableNames', ['accumulator', 'momenta'])

RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames',
                                                  ['ms', 'mom'])

ProximalAdagradSlotVariableNames = collections.namedtuple(
    'ProximalAdagradSlotVariableNames', ['accumulator'])

FtrlSlotVariableNames = collections.namedtuple('FtrlSlotVariableNames',
                                               ['accumulator', 'linear'])

ProximalYogiSlotVariableNames = collections.namedtuple(
    'ProximalYogiSlotVariableNames', ['v', 'm'])

FrequencyEstimatorSlotVariableNames = collections.namedtuple(
    'FrequencyEstimatorSlotVariableNames', ['last_hit_step'])

AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v'])

MomentumSlotVariables = collections.namedtuple('MomentumSlotVariables',
                                               ['momenta'])

AdagradMomentumSlotVariables = collections.namedtuple(
    'AdagradMomentumSlotVariables', ['accumulator', 'momenta'])

RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables',
                                              ['ms', 'mom'])

AdagradSlotVariables = collections.namedtuple('AdagradSlotVariables',
                                              ['accumulator'])

ProximalAdagradSlotVariables = collections.namedtuple(
    'ProximalAdagradSlotVariables', ['accumulator'])

FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
                                          ['accumulator', 'linear'])

ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
                                                   ['v', 'm'])

FrequencyEstimatorSlotVariables = collections.namedtuple(
    'FrequencyEstimatorSlotVariables', ['last_hit_step'])

VariablesAndOps = collections.namedtuple('VariablesAndOps', [
    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
    'retrieve_ops'
])


class _OptimizationParameters(object):
  """Parameters common to all optimizations."""

  def __init__(
      self,
      learning_rate: float,
      use_gradient_accumulation: bool,
      clip_weight_min: Optional[float],
      clip_weight_max: Optional[float],
      weight_decay_factor: Optional[float],
      multiply_weight_decay_factor_by_learning_rate: Optional[bool],
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation
    self.clip_weight_min = clip_weight_min
    self.clip_weight_max = clip_weight_max
    self.weight_decay_factor = weight_decay_factor
    self.multiply_weight_decay_factor_by_learning_rate = (
        multiply_weight_decay_factor_by_learning_rate)
    self.clip_gradient_min = clip_gradient_min
    self.clip_gradient_max = clip_gradient_max

    if not use_gradient_accumulation and (clip_gradient_min is not None or
                                          clip_gradient_max is not None):
      raise ValueError('When using gradient clipping limits, gradient '
                       'accumulation must be enabled.')


@tf_export(v1=['tpu.experimental.AdagradParameters'])
class AdagradParameters(_OptimizationParameters):
  """Optimization parameters for Adagrad with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      initial_accumulator: float = 0.1,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: used for updating embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(AdagradParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if initial_accumulator <= 0:
      raise ValueError(
          f'Adagrad initial_accumulator must be greater than zero. '
          f'Received: {initial_accumulator}.')
    self.initial_accumulator = initial_accumulator


class AdagradMomentumParameters(_OptimizationParameters):
  """Optimization parameters for Adagrad + Momentum with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdagradMomentumParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      momentum: float,
      use_nesterov: bool = False,
      exponent: float = 2,
      beta2: float = 1,
      epsilon: float = 1e-10,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: used for updating embedding table.
      momentum: Moving average parameter for the momentum accumulator.
      use_nesterov: Whether to use the Nesterov variant of momentum. See
        Sutskever et al., 2013.
      exponent: Exponent for the Adagrad accumulator.
      beta2: Moving average parameter for the Adagrad accumulator.
      epsilon: A small constant for numerical stability.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(AdagradMomentumParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if epsilon <= 0:
      raise ValueError('Adagrad momentum: epsilon must be positive')
    if exponent <= 0:
      raise ValueError('Adagrad momentum: exponent must be positive')
    self.momentum = momentum
    self.use_nesterov = use_nesterov
    self.exponent = exponent
    self.beta2 = beta2
    self.epsilon = epsilon


class ProximalAdagradParameters(_OptimizationParameters):
  """Optimization parameters for ProximalAdagrad with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.
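
  For example (a sketch following the pattern of the other optimizer classes
  in this module; `ProximalAdagradParameters` is referenced directly because
  it is not exported under `tf.tpu.experimental`):

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=ProximalAdagradParameters(0.1),
          ...))
  ```
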
  """

  def __init__(
      self,
      learning_rate: float,
      initial_accumulator: float = 0.1,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: used for updating embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(ProximalAdagradParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if initial_accumulator <= 0:
      raise ValueError(f'Adagrad initial_accumulator must be positive. '
                       f'Received: {initial_accumulator}.')
    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))

    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.initial_accumulator = initial_accumulator
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength


@tf_export(v1=['tpu.experimental.AdamParameters'])
class AdamParameters(_OptimizationParameters):
  """Optimization parameters for Adam with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      beta1: float = 0.9,
      beta2: float = 0.999,
      epsilon: float = 1e-08,
      lazy_adam: bool = True,
      sum_inside_sqrt: bool = True,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adam.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See
        `optimization_parameters.proto` for details.
      sum_inside_sqrt: This improves training speed. Please see
        `optimization_parameters.proto` for details.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(AdamParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if not use_gradient_accumulation and not lazy_adam:
      raise ValueError(
          'When disabling Lazy Adam, gradient accumulation must be used.')

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.lazy_adam = lazy_adam
    self.sum_inside_sqrt = sum_inside_sqrt


@tf_export(v1=['tpu.experimental.FtrlParameters'])
class FtrlParameters(_OptimizationParameters):
  """Optimization parameters for Ftrl with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      learning_rate_power: float = -0.5,
      initial_accumulator_value: float = 0.1,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      multiply_linear_by_learning_rate: bool = False,
      beta: float = 0,
      allow_zero_accumulator: bool = False,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Ftrl.

    Implements FTRL as described in the following [paper](
    https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)

    Args:
      learning_rate: a floating point value. The learning rate.
      learning_rate_power: A float value, must be less or equal to zero.
        Controls how the learning rate decreases during training. Use zero for a
        fixed learning rate. See section 3.1 in the
        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      multiply_linear_by_learning_rate: When true, multiplies the usages of the
        linear slot in the weight update by the learning rate. This is useful
        when ramping up learning rate from 0 (which would normally produce
        NaNs).
      beta: The beta parameter for FTRL.
      allow_zero_accumulator: Changes the implementation of the square root to
        allow for the case of initial_accumulator_value being zero. This will
        cause a slight performance drop.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(FtrlParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if learning_rate_power > 0.:
      raise ValueError('learning_rate_power must be less than or equal to 0. '
                       'got {}.'.format(learning_rate_power))

    if initial_accumulator_value < 0.:
      raise ValueError('initial_accumulator_value must be greater than or equal'
                       ' to 0. got {}.'.format(initial_accumulator_value))

    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))

    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.learning_rate_power = learning_rate_power
    self.initial_accumulator_value = initial_accumulator_value
    self.initial_linear_value = 0.0
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength
    self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
    self.beta = beta
    self.allow_zero_accumulator = allow_zero_accumulator


class ProximalYogiParameters(_OptimizationParameters):
  # pylint: disable=line-too-long
  """Optimization parameters for Proximal Yogi with TPU embeddings.

  Implements the Yogi optimizer as described in
  [Adaptive Methods for Nonconvex
  Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.
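
  For example (a sketch following the pattern of the other optimizer classes
  in this module; `ProximalYogiParameters` is referenced directly because it
  is not exported under `tf.tpu.experimental`):

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=ProximalYogiParameters(0.01),
          ...))
  ```
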
  """

  # pylint: enable=line-too-long

  def __init__(
      self,
      learning_rate: float = 0.01,
      beta1: float = 0.9,
      beta2: float = 0.999,
      epsilon: float = 1e-3,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      initial_accumulator_value: float = 1e-6,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Proximal Yogi.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(ProximalYogiParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))
    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength
    self.initial_accumulator_value = initial_accumulator_value


class MomentumParameters(_OptimizationParameters):
  """Optimization parameters for Momentum with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      momentum: float,
      use_nesterov: bool = False,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for momentum.

    Args:
      learning_rate: a floating point value. The learning rate.
      momentum: a floating point value.  The momentum.
      use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al.,
        2013). This implementation always computes gradients at the value of the
        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
        variable(s) track the values called `theta_t + mu*v_t` in the paper.
        This implementation is an approximation of the original formula, valid
        for high values of momentum. It will compute the "adjusted gradient" in
        NAG by assuming that the new gradient will be estimated by the current
        average gradient plus the product of momentum and the change in the
        average gradient.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(MomentumParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    self.momentum = momentum
    self.use_nesterov = use_nesterov


class RMSPropParameters(_OptimizationParameters):
  """Optimization parameters for RMSProp with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=RMSPropParameters(
              learning_rate=0.1, rho=..., momentum=..., epsilon=...),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      rho: float,
      momentum: float,
      epsilon: float,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for RMS prop.

    Args:
      learning_rate: a floating point value. The learning rate.
      rho: Discounting factor for the history/coming gradient.
      momentum: A float value. The momentum.
      epsilon: Small value to avoid a zero denominator.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super(RMSPropParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    self.rho = rho
    self.momentum = momentum
    self.epsilon = epsilon


@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
class StochasticGradientDescentParameters(_OptimizationParameters):
  """Optimization parameters for stochastic gradient descent for TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=(
              tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for stochastic gradient descent.

    Args:
      learning_rate: a floating point value. The learning rate.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
    """
    super(StochasticGradientDescentParameters, self).__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )


class FrequencyEstimatorParameters(_OptimizationParameters):
  """Optimization parameters for Frequency Estimator TPU embeddings.

  This is a non-standard optimizer, which returns the estimated frequency of
  lookup for the feature passed to it. It should only be used on a table of
  width 1. The gradient fed back to the TPU embedding should always be zero.
  This can be accomplished by using `tf.stop_gradient` on the feature before
  using it.

  You must use the dynamic learning rate mechanism to set the 'learning rate'
  for this table to be a float32 cast of the global training step counter.

  See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more
  details on this optimizer.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=FrequencyEstimatorParameters(
              tau=..., max_delta=..., outlier_threshold=...,
              weight_exponent=...),
          ...))
  ```

  """

  def __init__(self, tau: float, max_delta: float, outlier_threshold: float,
               weight_exponent: float):
    """Optimization parameters for frequency estimator.

    Args:
      tau: Learning rate between (0, 1) that is used to update the array.
      max_delta: Maximum value of delta, the difference between the current
        global step and the last global step at which the row was sampled.
      outlier_threshold: Threshold used to determine whether the current update
        is an outlier.
      weight_exponent: The weight exponent used to transform the estimated delta
        into weights.
    """
    super(FrequencyEstimatorParameters, self).__init__(
        learning_rate=1.0,
        use_gradient_accumulation=True,
        clip_weight_min=None,
        clip_weight_max=None,
        weight_decay_factor=None,
        multiply_weight_decay_factor_by_learning_rate=None,
    )
    self.tau = tau
    self.max_delta = max_delta
    self.outlier_threshold = outlier_threshold
    self.weight_exponent = weight_exponent


DeviceConfig = collections.namedtuple('DeviceConfig',
                                      ['num_hosts', 'num_cores', 'job_name'])


class TPUEmbedding(object):
  """API for using TPU for embedding.

    Example:
    ```
    table_config_video = tpu_embedding.TableConfig(
        vocabulary_size=8, dimension=2,
        initializer=initializer, combiner='mean')
    table_config_user = tpu_embedding.TableConfig(
        vocabulary_size=4, dimension=2,
        initializer=initializer, combiner='mean')
    table_to_config_dict = {'video': table_config_video,
                            'user': table_config_user}
    feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
                              'favorited': tpu_embedding.FeatureConfig('video'),
                              'friends': tpu_embedding.FeatureConfig('user')}
    batch_size = 4
    num_hosts = 1
    optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
    mode = tpu_embedding.TRAINING
    embedding = tpu_embedding.TPUEmbedding(
        table_to_config_dict, feature_to_config_dict,
        batch_size, mode,
        optimization_parameters=optimization_parameters)

    batch_size_per_core = embedding.batch_size_per_core
    sparse_features_list = []
    for host in hosts:
      with ops.device(host):
        for _ in range(embedding.num_cores_per_host):
          sparse_features = {}
          sparse_features['watched'] = sparse_tensor.SparseTensor(...)
          sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
          sparse_features['friends'] = sparse_tensor.SparseTensor(...)
          sparse_features_list.append(sparse_features)

    enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)
    embedding_variables_and_ops = embedding.create_variables_and_ops()

    def computation():
      activations = embedding.get_activations()
      loss = compute_loss(activations)

      base_optimizer = gradient_descent.GradientDescentOptimizer(
          learning_rate=1)
      cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
          base_optimizer)

      train_op = cross_shard_optimizer.minimize(loss)
      gradients = (
          tpu_embedding_gradient.get_gradients_through_compute_gradients(
              cross_shard_optimizer, loss, activations))
      send_gradients_op = embedding.generate_send_gradients_op(gradients)
      with ops.control_dependencies([train_op, send_gradients_op]):
        loss = array_ops.identity(loss)

    loss = tpu.shard(computation,
                     num_shards=embedding.num_cores)

    with tf.compat.v1.Session() as sess:
      sess.run(tpu.initialize_system(embedding_config=
                                     embedding.config_proto))
      sess.run(variables.global_variables_initializer())
      sess.run(embedding_variables_and_ops.load_ops())
      sess.run(enqueue_ops)
      loss_val = sess.run(loss)
    ```

  Example with weight decay:

  >>> def learning_rate_fn(global_step):
  ...   return tf.compat.v1.train.polynomial_decay(
  ...     learning_rate=5e-5,
  ...     global_step=global_step,
  ...     decay_steps=100000,
  ...     end_learning_rate=0.0)
  >>> wordpiece_table_config = TableConfig(
  ...   vocabulary_size=119547,
  ...   dimension=256,
  ...   learning_rate_fn=learning_rate_fn)
  >>> wordpiece_feature_config = FeatureConfig(
  ...   table_id='bert/embeddings/word_embeddings',
  ...   max_sequence_length=512)
  >>> optimization_parameters = AdamParameters(
  ...   learning_rate=5e-5,
  ...   epsilon=1e-6,
  ...   weight_decay_factor=0.01,
  ...   multiply_weight_decay_factor_by_learning_rate=True)
  >>> tpu_embedding = TPUEmbedding(
  ...  table_to_config_dict={
  ...    'bert/embeddings/word_embeddings': wordpiece_table_config,
  ...  },
  ...  feature_to_config_dict={'input_ids': wordpiece_feature_config},
  ...  batch_size=128,
  ...  mode=TRAINING,
  ...  optimization_parameters=optimization_parameters,
  ...  master='')
  >>> with tf.Graph().as_default():
  ...   init_tpu_op = tf.compat.v1.tpu.initialize_system(
  ...     embedding_config=tpu_embedding.config_proto)
  ...   tf.compat.v1.Session().run(init_tpu_op)
  """

  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
  # the feature should not be used to update embedding table (cr/204852758,
  # cr/204940540). Also, this can support different combiners for different
  # features within the same table.
  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
  # to `FeatureConfig`?

  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
  # respectively?

  # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
  # for-loops around construction of inputs.

  # `optimization_parameter` applies to all tables. If the need arises,
  # we can add `optimization_parameters` to `TableConfig` to override this
  # global setting.
  def __init__(self,
               table_to_config_dict,
               feature_to_config_dict,
               batch_size,
               mode,
               master=None,
               optimization_parameters=None,
               cluster_def=None,
               pipeline_execution_with_tensor_core=False,
               partition_strategy='div',
               profile_data_directory=None,
               device_config=None,
               master_job_name=None):
    """API for using TPU for embedding lookups.

    Args:
      table_to_config_dict: A dictionary mapping from string of table name to
        `TableConfig`. Table refers to an embedding table, e.g. `params`
        argument to `tf.nn.embedding_lookup_sparse()`.
      feature_to_config_dict: A dictionary mapping from string of feature name
        to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
        e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
      batch_size: An `int` representing the global batch size.
      mode: `TRAINING` or `INFERENCE`.
      master: A `string` representing the TensorFlow master to use.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Must be set in training unless
        every table specifies its own optimizer; must be `None` in inference.
      cluster_def: A ClusterDef object describing the TPU cluster.
      pipeline_execution_with_tensor_core: Setting this to `True` makes
        training faster, but the trained model will differ if step N and step
        N+1 involve the same set of embedding IDs. Please see
        `tpu_embedding_configuration.proto` for details.
      partition_strategy: A string, either 'mod' or 'div', specifying how to
        map the lookup id to the embedding tensor. For more information see
        `tf.nn.embedding_lookup_sparse` and the sketch at the end of this
        docstring.
      profile_data_directory: Directory where embedding lookup statistics are
        stored. These statistics summarize information about the inputs to the
        embedding lookup operation, in particular, the average number of
        embedding IDs per example and how well the embedding IDs are load
        balanced across the system. The lookup statistics are used during TPU
        initialization for embedding table partitioning. Collection of lookup
        statistics is done at runtime by profiling the embedding inputs; only
        a small fraction of input samples are profiled to minimize host CPU
        overhead. Once a suitable number of samples are profiled, the lookup
        statistics are saved to table-specific files in the profile data
        directory generally at the end of a TPU training loop. The filename
        corresponding to each table is obtained by hashing table specific
        parameters (e.g., table name and number of features) and global
        configuration parameters (e.g., sharding strategy and task count). The
        same profile data directory can be shared among several models to reuse
        embedding lookup statistics.
      device_config: A DeviceConfig instance, used when `master` and
        `cluster_def` are both `None`.
      master_job_name: If set, overrides the master job name used to schedule
        embedding ops.

    Raises:
      ValueError: if any input is invalid.
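
    For illustration, here is a rough sketch (plain Python, with hypothetical
    numbers) of how the two `partition_strategy` values map a lookup id to an
    embedding shard. The actual sharding is performed by the TPU embedding
    engine; this is only an approximation of the id-to-shard assignment:

    ```
    vocabulary_size = 10
    num_shards = 4  # hypothetical shard count

    def shard_for_id_mod(lookup_id):
      # 'mod': ids are assigned to shards round-robin.
      return lookup_id % num_shards

    def shard_for_id_div(lookup_id):
      # 'div': ids are assigned to shards in contiguous blocks.
      ids_per_shard = -(-vocabulary_size // num_shards)  # ceiling division
      return lookup_id // ids_per_shard
    ```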
    """
    if partition_strategy not in ('div', 'mod'):
      raise ValueError(f'partition_strategy must be "div" or "mod". '
                       f'Received: {partition_strategy}.')
    self._partition_strategy = partition_strategy

    self._profile_data_directory = profile_data_directory

    _validate_table_to_config_dict(table_to_config_dict)
    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)

    _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict)
    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
    self._table_to_features_dict = (
        _create_table_to_features_dict(self._feature_to_config_dict))
    self._combiners = _create_combiners(self._table_to_config_dict,
                                        self._table_to_features_dict)

    self._batch_size = batch_size

    if master is None and cluster_def is None:
      if device_config is None:
        raise ValueError('When master and cluster_def are both None, '
                         'device_config must be set but is not.')
      if device_config.num_cores % device_config.num_hosts:
        raise ValueError('num_hosts ({}) should divide num_cores ({}) '
                         'but does not.'.format(device_config.num_cores,
                                                device_config.num_hosts))
      self._num_hosts = device_config.num_hosts
      self._num_cores = device_config.num_cores
      self._num_cores_per_host = self._num_cores // self._num_hosts
      self._hosts = [
          '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i)
          for i in range(self._num_hosts)
      ]
    else:
      tpu_system_metadata = (
          tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
              master,
              cluster_def=cluster_def))
      if tpu_system_metadata.num_cores == 0:
        raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
                         'TPUs.'.format(master))
      self._num_hosts = tpu_system_metadata.num_hosts
      if master_job_name is None:
        try:
          master_job_name = tpu_system_metadata_lib.master_job(
              master, cluster_def)
        except ValueError as e:
          raise ValueError(str(e) + ' Please specify a master_job_name.')
      self._hosts = []
      for device in tpu_system_metadata.devices:
        if 'device:CPU:' in device.name and (master_job_name is None or
                                             master_job_name in device.name):
          self._hosts.append(device.name)
      self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
      self._num_cores = tpu_system_metadata.num_cores

    _validate_batch_size(self._batch_size, self._num_cores)
    self._batch_size_per_core = self._batch_size // self._num_cores

    # TODO(shizhiw): remove `mode`?
    if mode == TRAINING:
      _validate_optimization_parameters(optimization_parameters,
                                        self._table_to_config_dict)
      self._optimization_parameters = optimization_parameters
    elif mode == INFERENCE:
      if optimization_parameters is not None:
        raise ValueError(f'`optimization_parameters` should be `None` '
                         f'for inference mode. '
                         f'Received: {optimization_parameters}.')
      self._optimization_parameters = (StochasticGradientDescentParameters(1.))
    else:
      raise ValueError('`mode` only supports {} and {}; got {}.'.format(
          TRAINING, INFERENCE, mode))
    self._mode = mode

    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
    # and create special handler for inference that inherits from
    # StochasticGradientDescentHandler with more user-friendly error message
    # on get_slot().
    self._optimizer_handler_dict = self._get_optimizer_handler_by_table()

    self._pipeline_execution_with_tensor_core = (
        pipeline_execution_with_tensor_core)
    self._learning_rate_fn = list(
        set(c.learning_rate_fn
            for c in self._table_to_config_dict.values()
            if c.learning_rate_fn is not None))
    self._learning_rate_fn_to_tag = {
        fn: tag for tag, fn in enumerate(self._learning_rate_fn)
    }

    self._config_proto = self._create_config_proto()

  @property
  def hosts(self):
    """A list of device names for CPU hosts.

    Returns:
      A list of device names for CPU hosts.
    """
    return copy.copy(self._hosts)

  # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
  # to be consistent with `tpu_embedding_configuration.proto`.
  @property
  def num_cores_per_host(self):
    """Number of TPU cores on a CPU host.

    Returns:
      Number of TPU cores on a CPU host.
    """
    return self._num_cores_per_host

  @property
  def num_cores(self):
    """Total number of TPU cores on all hosts.

    Returns:
      Total number of TPU cores on all hosts.
    """
    return self._num_cores

  @property
  def batch_size_per_core(self):
    """Batch size for each TPU core.

    The sparse tensors in the `enqueue_datas_list` passed to
    `generate_enqueue_ops` must have a batch dimension equal to this (for
    example, a global `batch_size` of 128 on 8 cores gives a per-core batch
    size of 16).

    Returns:
      Batch size for each TPU core.
    """
    return self._batch_size_per_core

  @property
  def config_proto(self):
    """Create embedding config proto for `tpu.initialize_system()`.

    Returns:
      an `TPUEmbeddingConfiguration` proto describing the desired
         configuration of the hardware embedding lookup tables, which
         is passed to `tpu.initialize_system()`.
    """
    return self._config_proto

  @property
  def table_to_config_dict(self):
    return copy.copy(self._table_to_config_dict)

  @property
  def feature_to_config_dict(self):
    return copy.copy(self._feature_to_config_dict)

  @property
  def table_to_features_dict(self):
    return copy.copy(self._table_to_features_dict)

  @property
  def optimization_parameters(self):
    return self._optimization_parameters

  def _create_config_proto(self):
    """Create `TPUEmbeddingConfiguration`."""
    config_proto = elc.TPUEmbeddingConfiguration()
    for table in self._table_to_config_dict:
      table_descriptor = config_proto.table_descriptor.add()
      table_descriptor.name = table

      table_config = self._table_to_config_dict[table]
      # For small tables, we pad to the number of hosts so that at least one
      # id will be assigned to each host.
      table_descriptor.vocabulary_size = max(table_config.vocabulary_size,
                                             len(self.hosts))
      table_descriptor.dimension = table_config.dimension

      optimization_parameters = (
          self._optimizer_handler_dict[table].get_optimization_parameters())

      parameters = table_descriptor.optimization_parameters
      if table_config.learning_rate:
        parameters.learning_rate.constant = table_config.learning_rate
      elif table_config.learning_rate_fn:
        parameters.learning_rate.dynamic.tag = (
            self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
      else:
        parameters.learning_rate.constant = (
            optimization_parameters.learning_rate)
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
          if optimization_parameters.use_gradient_accumulation else
          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)

      if optimization_parameters.clip_gradient_min is not None:
        parameters.gradient_clipping_limits.lower.value = (
            optimization_parameters.clip_gradient_min)
      if optimization_parameters.clip_gradient_max is not None:
        parameters.gradient_clipping_limits.upper.value = (
            optimization_parameters.clip_gradient_max)

      if optimization_parameters.clip_weight_min is not None:
        parameters.clipping_limits.lower.value = (
            optimization_parameters.clip_weight_min)
      if optimization_parameters.clip_weight_max is not None:
        parameters.clipping_limits.upper.value = (
            optimization_parameters.clip_weight_max)
      if optimization_parameters.weight_decay_factor:
        parameters.weight_decay_factor = (
            optimization_parameters.weight_decay_factor)
        if (optimization_parameters
            .multiply_weight_decay_factor_by_learning_rate):
          parameters.multiply_weight_decay_factor_by_learning_rate = True
      if table_config.hot_id_replication:
        parameters.hot_id_replication_configuration.status = (
            optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
      optimizer_handler = self._optimizer_handler_dict[table]
      optimizer_handler.set_optimization_parameters(table_descriptor)

    table_to_id = {
        table: i for i, table in enumerate(self._table_to_config_dict)
    }

    # Set feature descriptor field in the config proto.
    for table in self._table_to_features_dict:
      features = self._table_to_features_dict[table]
      for feature in features:
        feature_descriptor = config_proto.feature_descriptor.add()

        feature_descriptor.table_id = table_to_id[
            self._feature_to_config_dict[feature].table_id]
        if self._feature_to_config_dict[feature].max_sequence_length > 0:
          feature_descriptor.input_shape.extend([
              self._batch_size_per_core,
              self._feature_to_config_dict[feature].max_sequence_length
          ])
        else:
          feature_descriptor.input_shape.extend([self._batch_size_per_core])

    config_proto.mode = self._mode
    config_proto.num_hosts = self._num_hosts
    config_proto.num_tensor_cores = self._num_cores
    config_proto.sharding_strategy = (
        elc.TPUEmbeddingConfiguration.DIV_DEFAULT if self._partition_strategy
        == 'div' else elc.TPUEmbeddingConfiguration.MOD)
    config_proto.pipeline_execution_with_tensor_core = (
        self._pipeline_execution_with_tensor_core)
    if self._profile_data_directory:
      config_proto.profile_data_directory = self._profile_data_directory

    return config_proto

  def create_variables_and_ops(self,
                               embedding_variable_name_by_table=None,
                               slot_variable_names_by_table=None):
    """Create embedding and slot variables, with ops to load and retrieve them.

    N.B.: the ops that retrieve the embedding variables (including slot
    variables) are returned as lambda functions, because the call site might
    want to impose control dependencies between the TPU computation and the
    retrieval. For example, the following snippet ensures the TPU computation
    finishes first, and only then pulls the variables back from TPU to CPU.

    ```
    update_ops = []
    with ops.control_dependencies([loss]):
      for op_fn in retrieve_parameters_op_fns:
        update_ops.append(op_fn())
    ```

    Args:
      embedding_variable_name_by_table: A dictionary mapping from string of
        table name to string of embedding variable name. If `None`, the
        table name itself is used as the embedding variable name.
      slot_variable_names_by_table: A dictionary mapping from string of table
        name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
        `None`, defaults from `get_default_slot_variable_names()` will be used.

    Returns:
      `tpu_embedding.VariablesAndOps` with:
        A dictionary mapping from string of table name to embedding variables,
        A dictionary mapping from string of table name to AdagradSlotVariables,
         AdamSlotVariables etc with slot variables,
        A function which returns a list of ops to load embedding and slot
         variables from CPU to TPU.
        A function which returns a list of ops to retrieve embedding and slot
         variables from TPU to CPU.
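
    A typical session-based flow, sketched here with a hypothetical
    `embedding` instance (mirroring the class-level example above):

    ```
    embedding_variables_and_ops = embedding.create_variables_and_ops()
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.tpu.initialize_system(
          embedding_config=embedding.config_proto))
      sess.run(tf.compat.v1.global_variables_initializer())
      # Push the freshly initialized CPU variables to the TPU.
      sess.run(embedding_variables_and_ops.load_ops())
      ...  # training loop
      # Pull the trained tables back to the CPU, e.g. before checkpointing.
      sess.run(embedding_variables_and_ops.retrieve_ops())
    ```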
    """
    embedding_variables_by_table = {}
    slot_variables_by_table = {}
    load_op_fns = []
    retrieve_op_fns = []

    for i, table in enumerate(self._table_to_config_dict):
      if embedding_variable_name_by_table:
        embedding_variable_name = embedding_variable_name_by_table[table]
      else:
        embedding_variable_name = table
      if slot_variable_names_by_table:
        slot_variable_names = slot_variable_names_by_table[table]
      else:
        optimizer_handler = self._optimizer_handler_dict[table]
        slot_variable_names = (
            optimizer_handler.get_default_slot_variable_names(table))

      # TODO(b/139144091): Multi-host support for mid-level API in
      #  eager context (TF 2.0)
      # Workaround below allows single-host use case in TF 2.0
      if context.executing_eagerly():
        device = ''
      else:
        device = _create_device_fn(self._hosts)

      with ops.device(device):
        table_variables = _create_partitioned_variables(
            name=embedding_variable_name,
            num_hosts=self._num_hosts,
            vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
            embedding_dimension=self._table_to_config_dict[table].dimension,
            initializer=self._table_to_config_dict[table].initializer,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        embedding_variables_by_table[table] = table_variables

        # Only load the embedding config into the load/retrieve nodes for
        # the first table on the first host; the other nodes use the config
        # from the first node.
        config = None if i else self.config_proto.SerializeToString()
        slot_variables_for_table, load_ops_fn, retrieve_ops_fn = (
            self._optimizer_handler_dict[table].create_variables_and_ops(
                table, slot_variable_names, self._num_hosts,
                self._table_to_config_dict[table], table_variables, config))
        slot_variables_by_table[table] = slot_variables_for_table
        load_op_fns.append(load_ops_fn)
        retrieve_op_fns.append(retrieve_ops_fn)

    def load_ops():
      """Calls and returns the load ops for each embedding table.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_ops_list = []
      for load_op_fn in load_op_fns:
        load_ops_list.extend(load_op_fn())
      return load_ops_list

    def retrieve_ops():
      """Calls and returns the retrieve ops for each embedding table.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_ops_list = []
      for retrieve_op_fn in retrieve_op_fns:
        retrieve_ops_list.extend(retrieve_op_fn())
      return retrieve_ops_list

    return VariablesAndOps(embedding_variables_by_table,
                           slot_variables_by_table, load_ops, retrieve_ops)

  def generate_enqueue_ops(
      self,
      enqueue_datas_list,
      mode_override=None,
      ragged=False,
  ):
    """Generate enqueue ops.

    Args:
      enqueue_datas_list: a list of dictionaries mapping from feature name
        (string) to `EnqueueData`. Each dictionary is for one TPU core.
        Dictionaries for the same host should be contiguous in the list.
      mode_override: A string input that overrides the mode specified in the
        TPUEmbeddingConfiguration. Supported values are {'unspecified',
        'inference', 'training', 'backward_pass_only'}. When set to
        'unspecified', the mode set in TPUEmbeddingConfiguration is used,
        otherwise mode_override is used (optional).
      ragged: If True, creates RaggedTensor enqueue ops rather than
        SparseTensor.

    Returns:
      Ops to enqueue to TPU for embedding.
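
    A minimal sketch for a single host with two cores; the per-core
    `SparseTensor`s here are hypothetical, and `EnqueueData` (defined earlier
    in this module) is assumed to accept the `embedding_indices`,
    `sample_indices` and `aggregation_weights` fields as keyword arguments:

    ```
    enqueue_datas_list = []
    for core in range(2):
      sparse_ids = sparse_input_per_core[core]  # a tf.sparse.SparseTensor
      enqueue_datas_list.append({
          'input_ids': EnqueueData(
              embedding_indices=sparse_ids.values,
              sample_indices=sparse_ids.indices,
              aggregation_weights=None),
      })
    enqueue_ops = tpu_embedding.generate_enqueue_ops(enqueue_datas_list)
    ```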
    """
    self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
    return [
        self._generate_enqueue_op(  # pylint: disable=g-complex-comprehension
            enqueue_datas,
            device_ordinal=i % self._num_cores_per_host,
            mode_override=mode_override,
            ragged=ragged,
        ) for i, enqueue_datas in enumerate(enqueue_datas_list)
    ]

  def _validate_generate_enqueue_ops_enqueue_datas_list(self,
                                                        enqueue_datas_list):
    """Validate `enqueue_datas_list`."""

    def _check_agreement(data, name, feature, enqueue_data):
      """Helper function to check device agreement."""
      if (data is not None and
          data.device != enqueue_data.embedding_indices.device):
        raise ValueError('Device of {0} does not agree with that of '
                         'embedding_indices for feature {1}.'.format(
                             name, feature))

    feature_set = set(self._feature_to_config_dict.keys())
    contiguous_device = None
    for i, enqueue_datas in enumerate(enqueue_datas_list):
      used_feature_set = set(enqueue_datas.keys())

      # Check features are valid.
      missing_feature_set = feature_set - used_feature_set
      if missing_feature_set:
        raise ValueError('`enqueue_datas_list[{}]` is missing a feature '
                         'that is in `feature_to_config_dict`: {}.'.format(
                             i, missing_feature_set))

      extra_feature_set = used_feature_set - feature_set
      if extra_feature_set:
        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
                         'in `feature_to_config_dict`: {}.'.format(
                             i, extra_feature_set))

      device = None
      device_feature = None
      for feature, enqueue_data in six.iteritems(enqueue_datas):
        combiner = self._table_to_config_dict[
            self._feature_to_config_dict[feature].table_id].combiner

        if isinstance(enqueue_data, EnqueueData):
          if enqueue_data.sample_indices is None and combiner:
            logging.warn(
                'No sample indices set for feature %s table %s but '
                'combiner is set to %s.', feature,
                self._feature_to_config_dict[feature].table_id, combiner)
          _check_agreement(enqueue_data.sample_indices, 'sample_indices',
                           feature, enqueue_data)
          _check_agreement(enqueue_data.aggregation_weights,
                           'aggregation_weights', feature, enqueue_data)

        elif isinstance(enqueue_data, RaggedEnqueueData):
          if enqueue_data.row_splits is None and combiner:
            logging.warn(
                'No row splits set for feature %s table %s but '
                'combiner is set to %s.', feature,
                self._feature_to_config_dict[feature].table_id, combiner)
          _check_agreement(enqueue_data.row_splits, 'row_splits', feature,
                           enqueue_data)
          _check_agreement(enqueue_data.aggregation_weights,
                           'aggregation_weights', feature, enqueue_data)
        else:
          raise ValueError(
              '`enqueue_datas_list[{}]` has a feature that is not mapped to '
              '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
                  i, feature))
        # Check all features are on the same device.
        if device is None:
          device = enqueue_data.embedding_indices.device
          device_feature = feature
        else:
          if device != enqueue_data.embedding_indices.device:
            raise ValueError('Devices are different between features in '
                             '`enqueue_datas_list[{}]`; '
                             'devices: {}, {}; features: {}, {}.'.format(
                                 i, device,
                                 enqueue_data.embedding_indices.device, feature,
                                 device_feature))

      if i % self._num_cores_per_host:
        if device != contiguous_device:
          raise ValueError('`enqueue_datas` on the same host must be '
                           'contiguous in `enqueue_datas_list`; '
                           '`enqueue_datas_list[{}]` is on device {}, '
                           'but is expected to be on device {}.'.format(
                               i, device, contiguous_device))
      else:
        contiguous_device = device

  def _generate_enqueue_op(self,
                           enqueue_datas,
                           device_ordinal,
                           mode_override=None,
                           ragged=False):
    """Creates op for enqueuing batch to TPU."""
    enqueue_data0 = list(enqueue_datas.values())[0]
    with ops.colocate_with(enqueue_data0.embedding_indices):
      return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
          device_ordinal=device_ordinal,
          combiners=self._combiners,
          mode_override=mode_override,
          **self._format_for_tpu_embedding_arbitrary_tensor_batch(
              enqueue_datas, ragged))

  def _format_for_tpu_embedding_arbitrary_tensor_batch(self, enqueue_datas,
                                                       ragged):
    """Format features for `enqueue_tpu_embedding_arbitrary_tensor_batch()`.

    Args:
      enqueue_datas: a `Dict` mapping feature names to `EnqueueData` or
        `RaggedEnqueueData` objects for embedding.
      ragged: If True, extract row splits from the data rather than sample
        indices.

    Returns:
      Dict of arguments for `enqueue_tpu_embedding_arbitrary_tensor_batch()`.
    """

    kwargs = {
        'sample_indices_or_row_splits': [],
        'embedding_indices': [],
        'aggregation_weights': [],
    }
    int_zeros = array_ops.zeros((0,), dtype=dtypes.int64)
    float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)
    for table in self._table_to_features_dict:
      features = self._table_to_features_dict[table]
      for feature in features:
        enqueue_data = enqueue_datas[feature]
        if ragged:
          kwargs['sample_indices_or_row_splits'].append(
              enqueue_data.row_splits if enqueue_data
              .row_splits is not None else int_zeros)
        else:
          if (self._feature_to_config_dict[feature].max_sequence_length > 0 and
              enqueue_data.sample_indices is not None and
              enqueue_data.sample_indices.shape[1] == 2):
            # Pad the sample indices as if the enqueued sparse tensor is rank 2.
            sample_indices = array_ops.pad(
                enqueue_data.sample_indices, paddings=[[0, 0], [0, 1]])
            kwargs['sample_indices_or_row_splits'].append(sample_indices)
          else:
            # If sample_indices is rank 1 or not present, treat the input as
            # a dense tensor.
            if (enqueue_data.sample_indices is None or
                enqueue_data.sample_indices.shape[1] == 1):
              kwargs['sample_indices_or_row_splits'].append(int_zeros)
            else:
              kwargs['sample_indices_or_row_splits'].append(
                  enqueue_data.sample_indices)

        kwargs['aggregation_weights'].append(
            enqueue_data.aggregation_weights if enqueue_data
            .aggregation_weights is not None else float_zeros)

        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
    return kwargs

  def get_activations(self):
    """Get activations for features.

    This should be called within the `computation` that is passed to
    `tpu.replicate` and friends.

    Returns:
      A dictionary mapping from feature name (string) to activation `Tensor`.
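
    A sketch of typical usage inside the replicated computation (`model_fn`
    below is a hypothetical model function):

    ```
    def computation():
      activations = tpu_embedding.get_activations()
      word_embeddings = activations['input_ids']
      loss = model_fn(word_embeddings)
      return loss
    ```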
    """
    recv_activations = tpu_ops.recv_tpu_embedding_activations(
        num_outputs=len(self._feature_to_config_dict),
        config=self._config_proto.SerializeToString())

    activations = collections.OrderedDict()
    index = 0
    for table in self._table_to_features_dict:
      for feature in self._table_to_features_dict[table]:
        activations[feature] = recv_activations[index]
        index += 1
    return activations

  def generate_send_gradients_op(self, feature_to_gradient_dict, step=None):
    """Send gradient to TPU embedding.

    Args:
      feature_to_gradient_dict: dict mapping feature names to gradient wrt
        activations.
      step: the current global step, used for dynamic learning rate.

    Returns:
      SendTPUEmbeddingGradients Op.

    Raises:
      RuntimeError: If `mode` is not `TRAINING`.
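
    A sketch of how gradients with respect to the activations might be wired
    back (`loss` and `model_fn` below are hypothetical):

    ```
    activations = tpu_embedding.get_activations()
    loss = model_fn(activations)
    gradients = tf.gradients(loss, list(activations.values()))
    feature_to_gradient_dict = dict(zip(activations.keys(), gradients))
    send_gradients_op = tpu_embedding.generate_send_gradients_op(
        feature_to_gradient_dict,
        step=tf.compat.v1.train.get_or_create_global_step())
    ```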
    """
    if self._mode != TRAINING:
      raise RuntimeError('Gradients should only be sent to TPU embedding in '
                         'training mode; got mode {}.'.format(self._mode))
    if step is None and self._learning_rate_fn:
      raise ValueError('There are dynamic learning rates but step is None.')

    gradients = []
    for table in self._table_to_features_dict:
      for feature in self._table_to_features_dict[table]:
        gradients.append(feature_to_gradient_dict[feature])

    return tpu_ops.send_tpu_embedding_gradients(
        inputs=gradients,
        learning_rates=[
            math_ops.cast(fn(step), dtype=dtypes.float32)
            for fn in self._learning_rate_fn
        ],
        config=self.config_proto.SerializeToString())

  def _get_optimizer_handler_by_table(self):
    optimizer_handlers = {}
    for table, table_config in self.table_to_config_dict.items():
      if table_config.optimization_parameters is not None:
        optimizer = table_config.optimization_parameters
      else:
        optimizer = self._optimization_parameters
      optimizer_handlers[table] = _get_optimization_handler(optimizer)

    return optimizer_handlers


def _validate_table_to_config_dict(table_to_config_dict):
  """Validate `table_to_config_dict`."""
  for k, v in six.iteritems(table_to_config_dict):
    if not isinstance(v, TableConfig):
      raise ValueError('Value of `table_to_config_dict` must be of type '
                       '`TableConfig`, got {} for {}.'.format(type(v), k))


def _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict):
  """Validate `feature_to_config_dict`."""
  used_table_set = set(
      [feature.table_id for feature in feature_to_config_dict.values()])
  table_set = set(table_to_config_dict.keys())

  unused_table_set = table_set - used_table_set
  if unused_table_set:
    raise ValueError(
        '`table_to_config_dict` specifies a table that is not '
        'used in `feature_to_config_dict`: {}.'.format(unused_table_set))

  extra_table_set = used_table_set - table_set
  if extra_table_set:
    raise ValueError(
        '`feature_to_config_dict` refers to a table that is not '
        'specified in `table_to_config_dict`: {}.'.format(extra_table_set))


def _validate_batch_size(batch_size, num_cores):
  if batch_size % num_cores:
    raise ValueError('`batch_size` is not a multiple of the number of '
                     'cores. `batch_size`={}, `num_cores`={}.'.format(
                         batch_size, num_cores))


def _validate_optimization_parameters(optimization_parameters,
                                      table_to_config_dict):
  """Validate global optimization_parameters and per table optimizers.

  If the global optimizer is `None`, every table optimizer must be non-`None`.

  Args:
      optimization_parameters: global optimizer provided in `TPUEmbedding`
        constructor.
      table_to_config_dict: A dictionary mapping from string of table name to
        `TableConfig`.
  """
  tbl_optimizer_missing = False
  for _, table_config in table_to_config_dict.items():
    if table_config.optimization_parameters is None:
      tbl_optimizer_missing = True
      break

  if optimization_parameters:
    if not isinstance(optimization_parameters, _OptimizationParameters):
      raise ValueError('`optimization_parameters` must inherit from '
                       '`_OptimizationParameters`. '
                       '`type(optimization_parameters)`={}'.format(
                           type(optimization_parameters)))
  else:
    # Missing global optimization_parameters.
    if tbl_optimizer_missing:
      raise ValueError('`optimization_parameters` is missing; it must be '
                       'set when any table lacks its own optimizer.')


class _OptimizerHandler(object):
  """Interface class for handling optimizer specific logic."""

  def __init__(self, optimization_parameters):
    self._optimization_parameters = optimization_parameters

  def get_optimization_parameters(self):
    return self._optimization_parameters

  def set_optimization_parameters(self, table_descriptor):
    raise NotImplementedError()

  def get_default_slot_variable_names(self, table):
    raise NotImplementedError()

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    raise NotImplementedError()


class _AdagradHandler(_OptimizerHandler):
  """Handles Adagrad specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adagrad.SetInParent()

  def get_default_slot_variable_names(self, table):
    return AdagradSlotVariableNames('{}/{}'.format(table, 'Adagrad'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    accumulator_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator)
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    slot_variables = AdagradSlotVariables(accumulator_variables)

    def load_ops_fn():
      """Returns the retrieve ops for AdaGrad embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      config = config_proto
      load_op_list = []
      for host_id, table_variable, accumulator_variable in zip(
          range(num_hosts), table_variables, accumulator_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adagrad_parameters(
                  parameters=table_variable,
                  accumulators=accumulator_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for AdaGrad embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      config = config_proto
      retrieve_op_list = []
      for host_id, table_variable, accumulator_variable in (zip(
          range(num_hosts), table_variables, accumulator_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_accumulator = (
              tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(accumulator_variable, retrieved_accumulator))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _AdagradMomentumHandler(_OptimizerHandler):
  """Handles Adagrad with Momentum specific logic.

  Creates slot variables and defines their initializers. Defines load/retrieve
  operations to be used for loading variables into TPU memory (from host memory)
  and retrieving variables from TPU memory (into host memory).
  """

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adagrad_momentum.SetInParent()
    table_descriptor.optimization_parameters.adagrad_momentum.momentum = (
        self._optimization_parameters.momentum)
    table_descriptor.optimization_parameters.adagrad_momentum.use_nesterov = (
        self._optimization_parameters.use_nesterov)
    table_descriptor.optimization_parameters.adagrad_momentum.exponent = (
        self._optimization_parameters.exponent)
    table_descriptor.optimization_parameters.adagrad_momentum.beta2 = (
        self._optimization_parameters.beta2)
    table_descriptor.optimization_parameters.adagrad_momentum.epsilon = (
        self._optimization_parameters.epsilon)

  def get_default_slot_variable_names(self, table):
    return AdagradMomentumSlotVariableNames(
        '{}/{}/Accumulator'.format(table, 'AdagradMomentum'),
        '{}/{}/Momentum'.format(table, 'AdagradMomentum'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    accumulator_initializer = init_ops.zeros_initializer()
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    momenta_initializer = init_ops.zeros_initializer()
    momenta_variables = _create_partitioned_variables(
        name=slot_variable_names.momenta,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=momenta_initializer)
    slot_variables = AdagradMomentumSlotVariables(accumulator_variables,
                                                  momenta_variables)

    def load_ops_fn():
      """Returns the load ops for AdaGrad with momentum embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      config = config_proto
      load_op_list = []
      for host_id, table_variable, accumulator_variable, momenta_variable in zip(
          range(num_hosts), table_variables, accumulator_variables,
          momenta_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adagrad_momentum_parameters(
                  parameters=table_variable,
                  accumulators=accumulator_variable,
                  momenta=momenta_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for AdaGrad with momentum embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      config = config_proto
      retrieve_op_list = []
      for host_id, table_variable, accumulator_variable, momenta_variable in (
          zip(
              range(num_hosts), table_variables, accumulator_variables,
              momenta_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_accumulator, retrieved_momenta = (
              tpu_ops.retrieve_tpu_embedding_adagrad_momentum_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(accumulator_variable, retrieved_accumulator),
              state_ops.assign(momenta_variable, retrieved_momenta))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _ProximalAdagradHandler(_OptimizerHandler):
  """Handles ProximalAdagrad specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.proximal_adagrad.SetInParent()
    table_descriptor.optimization_parameters.proximal_adagrad.l1 = (
        self._optimization_parameters.l1_regularization_strength)
    table_descriptor.optimization_parameters.proximal_adagrad.l2 = (
        self._optimization_parameters.l2_regularization_strength)

  def get_default_slot_variable_names(self, table):
    return ProximalAdagradSlotVariableNames('{}/{}'.format(
        table, 'ProximalAdagrad'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    accumulator_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator)
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    slot_variables = ProximalAdagradSlotVariables(accumulator_variables)

    def load_ops_fn():
      """Returns the retrieve ops for Proximal AdaGrad embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      config = config_proto
      load_op_list = []
      for host_id, table_variable, accumulator_variable in zip(
          range(num_hosts), table_variables, accumulator_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_proximal_adagrad_parameters(
                  parameters=table_variable,
                  accumulators=accumulator_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Proximal AdaGrad embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      config = config_proto
      retrieve_op_list = []
      for host_id, table_variable, accumulator_variable in (zip(
          range(num_hosts), table_variables, accumulator_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_accumulator = (
              tpu_ops.retrieve_tpu_embedding_proximal_adagrad_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(accumulator_variable, retrieved_accumulator))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _AdamHandler(_OptimizerHandler):
  """Handles Adam specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.adam.beta1 = (
        self._optimization_parameters.beta1)
    table_descriptor.optimization_parameters.adam.beta2 = (
        self._optimization_parameters.beta2)
    table_descriptor.optimization_parameters.adam.epsilon = (
        self._optimization_parameters.epsilon)
    table_descriptor.optimization_parameters.adam.use_non_lazy_adam = (
        not self._optimization_parameters.lazy_adam)
    table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = (
        self._optimization_parameters.sum_inside_sqrt)

  def get_default_slot_variable_names(self, table):
    return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'),
                                 '{}/{}/v'.format(table, 'Adam'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    m_initializer = init_ops.zeros_initializer()
    m_variables = _create_partitioned_variables(
        name=slot_variable_names.m,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=m_initializer)
    v_initializer = init_ops.zeros_initializer()
    v_variables = _create_partitioned_variables(
        name=slot_variable_names.v,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=v_initializer)
    slot_variables = AdamSlotVariables(m_variables, v_variables)

    def load_ops_fn():
      """Returns the retrieve ops for AdaGrad embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables, m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=table_variable,
                  momenta=m_variable,
                  velocities=v_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        # Set config to None to enforce that config is only loaded to the first
        # table.
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Adam embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, m_variable, v_variable in (zip(
          range(num_hosts), table_variables, m_variables, v_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_m, retrieved_v = (
              tpu_ops.retrieve_tpu_embedding_adam_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(m_variable, retrieved_m),
              state_ops.assign(v_variable, retrieved_v))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _FtrlHandler(_OptimizerHandler):
  """Handles Ftrl specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.ftrl.lr_power = (
        self._optimization_parameters.learning_rate_power)
    table_descriptor.optimization_parameters.ftrl.l1 = (
        self._optimization_parameters.l1_regularization_strength)
    table_descriptor.optimization_parameters.ftrl.l2 = (
        self._optimization_parameters.l2_regularization_strength)
    table_descriptor.optimization_parameters.ftrl.multiply_linear_by_lr = (
        self._optimization_parameters.multiply_linear_by_learning_rate)
    table_descriptor.optimization_parameters.ftrl.beta = (
        self._optimization_parameters.beta)
    table_descriptor.optimization_parameters.ftrl.allow_zero_accumulator = (
        self._optimization_parameters.allow_zero_accumulator)

  def get_default_slot_variable_names(self, table):
    # These match the default slot variable names created by
    # tf.train.FtrlOptimizer.
    return FtrlSlotVariableNames(
        '{}/{}'.format(table, 'Ftrl'),  # accumulator
        '{}/{}'.format(table, 'Ftrl_1'))  # linear

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    accumulator_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator_value)
    accumulator_variables = _create_partitioned_variables(
        name=slot_variable_names.accumulator,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=accumulator_initializer)
    linear_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_linear_value)
    linear_variables = _create_partitioned_variables(
        name=slot_variable_names.linear,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=linear_initializer)
    slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables)

    def load_ops_fn():
      """Returns the retrieve ops for Ftrl embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      config = config_proto
      load_op_list = []
      for host_id, table_variable, accumulator_variable, linear_variable in zip(
          range(num_hosts), table_variables, accumulator_variables,
          linear_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_ftrl_parameters(
                  parameters=table_variable,
                  accumulators=accumulator_variable,
                  linears=linear_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Ftrl embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      config = config_proto
      retrieve_op_list = []
      for host_id, table_variable, accumulator_variable, linear_variable in zip(
          range(num_hosts), table_variables, accumulator_variables,
          linear_variables):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_accumulator, retrieved_linear = (
              tpu_ops.retrieve_tpu_embedding_ftrl_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(accumulator_variable, retrieved_accumulator),
              state_ops.assign(linear_variable, retrieved_linear))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _ProximalYogiHandler(_OptimizerHandler):
  """Handles Proximal Yogi specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.proximal_yogi.SetInParent()
    table_descriptor.optimization_parameters.proximal_yogi.beta1 = (
        self._optimization_parameters.beta1)
    table_descriptor.optimization_parameters.proximal_yogi.beta2 = (
        self._optimization_parameters.beta2)
    table_descriptor.optimization_parameters.proximal_yogi.epsilon = (
        self._optimization_parameters.epsilon)
    table_descriptor.optimization_parameters.proximal_yogi.l1 = (
        self._optimization_parameters.l1_regularization_strength)
    table_descriptor.optimization_parameters.proximal_yogi.l2 = (
        self._optimization_parameters.l2_regularization_strength)

  def get_default_slot_variable_names(self, table):
    return ProximalYogiSlotVariableNames(
        '{}/{}'.format(table, 'ProximalYogi'),  # v
        '{}/{}_1'.format(table, 'ProximalYogi'))  # m

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    v_initializer = init_ops.constant_initializer(
        self._optimization_parameters.initial_accumulator_value)
    v_variables = _create_partitioned_variables(
        name=slot_variable_names.v,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=v_initializer)
    m_initializer = init_ops.zeros_initializer()
    m_variables = _create_partitioned_variables(
        name=slot_variable_names.m,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=m_initializer)
    slot_variables = ProximalYogiSlotVariables(v_variables, m_variables)

    def load_ops_fn():
      """Returns the load ops for Proximal Yogi embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, v_variable, m_variable in (zip(
          range(num_hosts), table_variables, v_variables, m_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_proximal_yogi_parameters(
                  parameters=table_variable,
                  v=v_variable,
                  m=m_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        # Set config to None to enforce that config is only loaded to the first
        # table.
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Proximal Yogi embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, v_variable, m_variable in (zip(
          range(num_hosts), table_variables, v_variables, m_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_v, retrieved_m = (
              tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(v_variable, retrieved_v),
              state_ops.assign(m_variable, retrieved_m))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _MomentumHandler(_OptimizerHandler):
  """Handles Momentum specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    (table_descriptor.optimization_parameters.momentum.SetInParent())
    table_descriptor.optimization_parameters.momentum.momentum = (
        self._optimization_parameters.momentum)
    table_descriptor.optimization_parameters.momentum.use_nesterov = (
        self._optimization_parameters.use_nesterov)

  def get_default_slot_variable_names(self, table):
    return MomentumSlotVariableNames('{}/{}'.format(table, 'Momentum'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):

    momenta_initializer = init_ops.zeros_initializer()
    momenta_variables = _create_partitioned_variables(
        name=slot_variable_names.momenta,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=momenta_initializer)
    slot_variables = MomentumSlotVariables(momenta_variables)

    def load_ops_fn():
      """Returns the retrieve ops for Momentum embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, momenta_variable in (zip(
          range(num_hosts), table_variables, momenta_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters(
              parameters=table_variable,
              momenta=momenta_variable,
              table_name=table,
              num_shards=num_hosts,
              shard_id=host_id,
              config=config,
          )
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Momentum embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, momenta_variable in (zip(
          range(num_hosts), table_variables, momenta_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_momenta = (
              tpu_ops.retrieve_tpu_embedding_momentum_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config,
              ))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(momenta_variable, retrieved_momenta))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _RMSPropHandler(_OptimizerHandler):
  """Handles RMS prop specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    (table_descriptor.optimization_parameters.rms_prop.SetInParent())
    table_descriptor.optimization_parameters.rms_prop.rho = (
        self._optimization_parameters.rho)
    table_descriptor.optimization_parameters.rms_prop.epsilon = (
        self._optimization_parameters.epsilon)
    table_descriptor.optimization_parameters.rms_prop.momentum = (
        self._optimization_parameters.momentum)

  def get_default_slot_variable_names(self, table):
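    # E.g. a table named 'my_table' gets the slot names 'my_table/RMSProp/ms'
    # and 'my_table/RMSProp/mom'.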
    return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'),
                                    '{}/{}/mom'.format(table, 'RMSProp'))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):

    ms_variables = _create_partitioned_variables(
        name=slot_variable_names.ms,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=init_ops.zeros_initializer(),
    )
    mom_variables = _create_partitioned_variables(
        name=slot_variable_names.mom,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=init_ops.zeros_initializer(),
    )
    slot_variables = RMSPropSlotVariables(ms_variables, mom_variables)

    def load_ops_fn():
      """Returns the retrieve ops for RMS Prop embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, ms_variable, mom_variable in (zip(
          range(num_hosts), table_variables, ms_variables, mom_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters(
              parameters=table_variable,
              ms=ms_variable,
              mom=mom_variable,
              table_name=table,
              num_shards=num_hosts,
              shard_id=host_id,
              config=config,
          )
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for RMS Prop embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, ms_variable, mom_variable in (zip(
          range(num_hosts), table_variables, ms_variables, mom_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_ms, retrieved_mom = (
              tpu_ops.retrieve_tpu_embedding_rms_prop_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config,
              ))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(ms_variable, retrieved_ms),
              state_ops.assign(mom_variable, retrieved_mom))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _FrequencyEstimatorHandler(_OptimizerHandler):
  """Handles frequency estimator specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    table_descriptor.optimization_parameters.frequency_estimator.SetInParent()
    freq = table_descriptor.optimization_parameters.frequency_estimator
    freq.tau = self._optimization_parameters.tau
    freq.max_delta = self._optimization_parameters.max_delta
    freq.outlier_threshold = self._optimization_parameters.outlier_threshold
    freq.weight_exponent = self._optimization_parameters.weight_exponent

  def get_default_slot_variable_names(self, table):
    return FrequencyEstimatorSlotVariableNames(
        '{}/FrequencyEstimator'.format(table))

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
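    # The frequency estimator keeps a single scalar estimate per id, hence the
    # dimension-1 requirement below.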
    if table_config.dimension != 1:
      raise ValueError('FrequencyEstimator tables should only have a dimension '
                       'of 1. Received dimension {}'.format(
                           table_config.dimension))

    last_hit_step_variables = _create_partitioned_variables(
        name=slot_variable_names.last_hit_step,
        num_hosts=num_hosts,
        vocabulary_size=table_config.vocabulary_size,
        embedding_dimension=table_config.dimension,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        initializer=init_ops.zeros_initializer(),
    )
    slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables)

    def load_ops_fn():
      """Returns the retrieve ops for Frequency Estimator embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable, last_hit_step_variable in (zip(
          range(num_hosts), table_variables, last_hit_step_variables)):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_frequency_estimator_parameters(
                  parameters=table_variable,
                  last_hit_step=last_hit_step_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for Frequency Estimator embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable, last_hit_step_variable in (zip(
          range(num_hosts), table_variables, last_hit_step_variables)):
        with ops.colocate_with(table_variable):
          retrieved_table, retrieved_last_hit_step = (
              tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config,
              ))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table),
              state_ops.assign(last_hit_step_variable, retrieved_last_hit_step))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return slot_variables, load_ops_fn, retrieve_ops_fn


class _StochasticGradientDescentHandler(_OptimizerHandler):
  """Handles stochastic gradient descent specific logic."""

  def set_optimization_parameters(self, table_descriptor):
    (table_descriptor.optimization_parameters.stochastic_gradient_descent
     .SetInParent())

  def get_default_slot_variable_names(self, table):
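    # SGD keeps no slot variables, so there are no slot variable names.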
    return None

  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
                               table_config, table_variables, config_proto):
    del table_config

    def load_ops_fn():
      """Returns the retrieve ops for AdaGrad embedding tables.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      load_op_list = []
      config = config_proto
      for host_id, table_variable in enumerate(table_variables):
        with ops.colocate_with(table_variable):
          load_parameters_op = (
              tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters(
                  parameters=table_variable,
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
        config = None
        load_op_list.append(load_parameters_op)
      return load_op_list

    def retrieve_ops_fn():
      """Returns the retrieve ops for SGD embedding tables.

      Returns:
        A list of ops to retrieve embedding and slot variables from TPU to CPU.
      """
      retrieve_op_list = []
      config = config_proto
      for host_id, table_variable in enumerate(table_variables):
        with ops.colocate_with(table_variable):
          retrieved_table = (
              tpu_ops
              .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                  table_name=table,
                  num_shards=num_hosts,
                  shard_id=host_id,
                  config=config))
          retrieve_parameters_op = control_flow_ops.group(
              state_ops.assign(table_variable, retrieved_table))
        config = None
        retrieve_op_list.append(retrieve_parameters_op)
      return retrieve_op_list

    return None, load_ops_fn, retrieve_ops_fn


def _get_optimization_handler(optimization_parameters):
  """Gets the optimization handler given the parameter type."""
  if isinstance(optimization_parameters, AdagradParameters):
    return _AdagradHandler(optimization_parameters)
  elif isinstance(optimization_parameters, AdagradMomentumParameters):
    return _AdagradMomentumHandler(optimization_parameters)
  elif isinstance(optimization_parameters, ProximalAdagradParameters):
    return _ProximalAdagradHandler(optimization_parameters)
  elif isinstance(optimization_parameters, AdamParameters):
    return _AdamHandler(optimization_parameters)
  elif isinstance(optimization_parameters, FtrlParameters):
    return _FtrlHandler(optimization_parameters)
  elif isinstance(optimization_parameters, ProximalYogiParameters):
    return _ProximalYogiHandler(optimization_parameters)
  elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
    return _StochasticGradientDescentHandler(optimization_parameters)
  elif isinstance(optimization_parameters, MomentumParameters):
    return _MomentumHandler(optimization_parameters)
  elif isinstance(optimization_parameters, RMSPropParameters):
    return _RMSPropHandler(optimization_parameters)
  elif isinstance(optimization_parameters, FrequencyEstimatorParameters):
    return _FrequencyEstimatorHandler(optimization_parameters)
  raise NotImplementedError()


def _create_ordered_dict(d):
  """Create an OrderedDict from Dict."""
  return collections.OrderedDict((k, d[k]) for k in sorted(d))


def _create_combiners(table_to_config_dict, table_to_features_dict):
  """Create a per feature list of combiners, ordered by table."""
  combiners = []
  for table in table_to_config_dict:
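    # TableConfig allows combiner=None for dense inputs; fall back to 'sum'.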
    combiner = table_to_config_dict[table].combiner or 'sum'
    combiners.extend([combiner] * len(table_to_features_dict[table]))
  return combiners


def _create_table_to_features_dict(feature_to_config_dict):
  """Create mapping from table to a list of its features."""
  table_to_features_dict_tmp = {}
  for feature, feature_config in six.iteritems(feature_to_config_dict):
    if feature_config.table_id in table_to_features_dict_tmp:
      table_to_features_dict_tmp[feature_config.table_id].append(feature)
    else:
      table_to_features_dict_tmp[feature_config.table_id] = [feature]

  table_to_features_dict = collections.OrderedDict()
  for table in sorted(table_to_features_dict_tmp):
    table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
  return table_to_features_dict


def _create_device_fn(hosts):
  """Create device_fn() to use with _create_partitioned_variables()."""

  def device_fn(op):
    """Returns the `device` for `op`."""
    part_match = re.match(r'.*/part_(\d+)(/|$)', op.name)
    dummy_match = re.match(r'.*dummy_(\d+).*', op.name)
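    # E.g. an op named 'tpu_embedding/my_table/part_3/Read' is placed on
    # hosts[3].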
    if not part_match and not dummy_match:
      raise RuntimeError(
          'Internal Error: Expected {} to contain /part_* or dummy_*'.format(
              op.name))

    if part_match:
      idx = int(part_match.group(1))
    else:
      idx = int(dummy_match.group(1))  # pytype: disable=attribute-error

    device = hosts[idx]
    logging.debug('assigning %s to %s.', op, device)
    return device

  return device_fn


def _create_partitioned_variables(name,
                                  num_hosts,
                                  vocabulary_size,
                                  embedding_dimension,
                                  initializer,
                                  collections=None):  # pylint: disable=redefined-outer-name
  """Creates PartitionedVariables based on `num_hosts` for `table`."""

  num_slices = min(vocabulary_size, num_hosts)
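  # At most one slice per host; when vocabulary_size < num_hosts, the leftover
  # hosts receive dummy variables (see below).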

  var_list = list(
      variable_scope.get_variable(
          name,
          shape=(vocabulary_size, embedding_dimension),
          partitioner=partitioned_variables.fixed_size_partitioner(num_slices),
          dtype=dtypes.float32,
          initializer=initializer,
          collections=collections,
          trainable=False))

  if vocabulary_size >= num_hosts:
    return var_list

  # For the padded parts, define dummy variables to load into the TPU system.
  for idx in range(num_hosts - vocabulary_size):
    var_list.append(
        variable_scope.get_variable(
            'dummy_{}_{}'.format(vocabulary_size + idx, name),
            shape=(1, embedding_dimension),
            dtype=dtypes.float32,
            initializer=initializer,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            trainable=False))

  return var_list
