# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with and creating SaveableObjects."""
import functools
import six

from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export

# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = {
    "Variable",
    "VariableV2",
    "AutoReloadVariable",
    "VarHandleOp",
    "ReadVariableOp",
}


def set_cpu0(device_string):
  """Returns `device_string` rewritten to target CPU device 0.

  The string is parsed into a `DeviceSpec`, its device type and device index
  are replaced with "CPU" and 0, and the result is converted back to a string.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  cpu_spec = pydev.DeviceSpec.from_string(device_string).replace(
      device_type="CPU", device_index=0)
  return cpu_spec.to_string()


class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject for reference variables.

  The variable is saved through a single `SaveSpec` and restored by assigning
  the restored tensor back to the saveable's op.
  """

  def __init__(self, var, slice_spec, name):
    """Wraps `var` in one `SaveSpec`.

    Args:
      var: The reference variable to save; its `dtype` is used for the spec.
      slice_spec: Slice specification forwarded to the `SaveSpec`.
      name: Name forwarded to both the `SaveSpec` and the SaveableObject.
    """
    save_spec = saveable_object.SaveSpec(
        var, slice_spec, name, dtype=var.dtype)
    super().__init__(var, [save_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the variable.

    Args:
      restored_tensors: Sequence whose first element is the tensor to assign.
      restored_shapes: `None`, or a sequence whose first element is the shape
        the restored tensor is reshaped to before assignment.

    Returns:
      The result of `state_ops.assign`.
    """
    value = restored_tensors[0]
    reshape_requested = restored_shapes is not None
    if reshape_requested:
      value = array_ops.reshape(value, restored_shapes[0])
    # Shapes are only validated when no reshape was requested and the
    # variable's static shape is fully defined.
    check_shape = (
        not reshape_requested and self.op.get_shape().is_fully_defined())
    return state_ops.assign(self.op, value, validate_shape=check_shape)


class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject for ResourceVariables and tensors read from them."""

  def __init__(self, var, slice_spec, name):
    """Builds the `SaveSpec` for a resource variable or a read tensor.

    Args:
      var: Either an `ops.Tensor` (whose op's first input is used as the
        handle) or a resource variable.
      slice_spec: Slice specification forwarded to the `SaveSpec`.
      name: Name forwarded to both the `SaveSpec` and the SaveableObject.

    Raises:
      ValueError: If `var` is neither a Tensor nor a resource variable.
    """
    self._var_device = var.device
    self._var_shape = var.shape

    if isinstance(var, ops.Tensor):
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif resource_variable_ops.is_resource_variable(var):

      def read_to_cpu():
        """Reads the variable lazily and returns a CPU copy (or `None`)."""
        with ops.device(var.device):
          # A SaveSpec tensor value of `None` indicates that the variable is
          # uninitialized.
          if context.executing_eagerly() and not var.is_initialized():
            return None
          # Reading without a copy limits memory usage.
          value = var.read_value_no_copy()
          # Copy to CPU on the same machine so that variables placed on
          # non-CPU devices can be checkpointed.
          with ops.device("/device:CPU:0"):
            return array_ops.identity(value)

      self.handle_op = var.handle
      tensor = read_to_cpu
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          f" Got: {repr(var)}")

    save_spec = saveable_object.SaveSpec(
        tensor, slice_spec, name, dtype=var.dtype, device=var.device)
    super().__init__(var, [save_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restores tensors. Raises ValueError if incompatible shape found.

    Args:
      restored_tensors: Sequence whose first element is the tensor to assign.
      restored_shapes: `None`, or a sequence whose first element is the shape
        the restored tensor is reshaped to before assignment.

    Returns:
      The result of `shape_safe_assign_variable_handle`.

    Raises:
      ValueError: If the assignment rejects the restored tensor's shape.
    """
    value = restored_tensors[0]
    if restored_shapes is not None:
      value = array_ops.reshape(value, restored_shapes[0])

    # Copy the restored tensor to the variable's device before assigning.
    with ops.device(self._var_device):
      value = array_ops.identity(value)
      try:
        return resource_variable_ops.shape_safe_assign_variable_handle(
            self.handle_op, self._var_shape, value)
      except ValueError as e:
        raise ValueError(
            f"Received incompatible tensor with shape {value.shape} "
            f"when attempting to restore variable with shape {self._var_shape} "
            f"and name {self.name}.") from e


def _tensor_comes_from_variable(v):
  """Returns whether `v` is a Tensor whose op type is in `_VARIABLE_OPS`."""
  if not isinstance(v, ops.Tensor):
    return False
  return v.op.type in _VARIABLE_OPS


def saveable_objects_for_op(op, name):
  """Generates the `SaveableObject`s needed to save and restore `op`.

  Dispatches on the type of `op`:
    * a `SaveableObject` is yielded unchanged;
    * a list, tuple or `PartitionedVariable` is treated as a set of slices of
      one tensor, yielding one saveable per slice;
    * a `Trackable` that is not a `Variable` is expanded through
      `saveable_objects_from_trackable`, recursing on every result;
    * anything else is handled as a single variable or tensor.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or if `op` converts to a tensor that
      does not come from a variable op.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  # Op types saved through ReferenceVariableSaveable; every other variable op
  # goes through ResourceVariableSaveable.
  reference_op_types = ("Variable", "VariableV2", "AutoReloadVariable")

  if not isinstance(name, six.string_types):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        f"trackable operations. Name is not a string: {name}")

  if isinstance(op, saveable_object.SaveableObject):
    yield op
    return

  if isinstance(op, (list, tuple, variables.PartitionedVariable)):
    # A set of slices.
    slices = list(op) if isinstance(op, variables.PartitionedVariable) else op
    slice_name = None
    # pylint: disable=protected-access
    for variable in slices:
      if isinstance(variable, saveable_object.SaveableObject):
        yield variable
        continue
      if not isinstance(variable, variables.Variable):
        raise ValueError(f"Slices must all be Variables: {variable}")
      if not variable._save_slice_info:
        raise ValueError(f"Slices must all be slices: {variable}")
      # Every slice must report the same full name as the first one seen.
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            f"Slices must all be from the same tensor: {slice_name} != "
            f"{variable._save_slice_info.full_name}")
      if variable.op.type in reference_op_types:
        saveable_cls = ReferenceVariableSaveable
      else:
        saveable_cls = ResourceVariableSaveable
      yield saveable_cls(variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
    return

  if (isinstance(op, trackable.Trackable) and
      not isinstance(op, variables.Variable)):
    for attr, factory in saveable_objects_from_trackable(op).items():
      # Keep original name for classes masquerading as variables.
      full_name = (
          name if attr == trackable.VARIABLE_VALUE_KEY else name + "_" + attr)
      produced = factory(full_name) if callable(factory) else factory
      yield from saveable_objects_for_op(produced, produced.name)
    return

  # A variable or tensor.
  if isinstance(op, resource_variable_ops.BaseResourceVariable):
    # pylint: disable=protected-access
    variable = op._graph_element if op._in_graph_mode else op
    # pylint: enable=protected-access
    yield ResourceVariableSaveable(variable, "", name)
    return

  if context.executing_eagerly():
    raise ValueError("Can only save/restore ResourceVariables when "
                     f"executing eagerly, got type: {type(op)}.")

  variable = ops.convert_to_tensor(op, as_ref=True)
  if not _tensor_comes_from_variable(variable):
    raise TypeError(
        "names_to_saveables must be a dict mapping string "
        f"names to Tensors/Variables. Not a variable: {variable}")
  if variable.op.type in reference_op_types:
    yield ReferenceVariableSaveable(variable, "", name)
  else:
    yield ResourceVariableSaveable(variable, "", name)


def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  The input is flattened and sorted by each element's `name` before being
  processed, so every element must expose a `name` attribute.

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name.  Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    f"list. Got {op_list}")
  # List casting is necessary to support sets.
  op_list = nest.flatten(list(op_list))
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    # Resource and reference variables are handled by the final `else` branch
    # below rather than by the generic Trackable branch.
    resource_or_ref_variable = (
        isinstance(var, resource_variable_ops.BaseResourceVariable) or
        isinstance(var, variables.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      # NOTE(review): no duplicate-name check here; a SaveableObject replaces
      # any existing entry stored under the same name.
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError(
            f"At least two variables have the same name: {var.name}")
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      # Slices of the same tensor are collected into one list keyed by the
      # full (unsliced) name.
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           f"{name}")
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif isinstance(var, trackable.Trackable) and not resource_or_ref_variable:
      # Expand the trackable into its saveables (calling factories with no
      # arguments) and merge the result of a recursive call. The recursive
      # call uses the default `convert_variable_to_tensor`.
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in saveable_objects_from_trackable(var).values()]
      names_to_saveables.update(
          op_list_to_dict(trackable_saveables))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
        if not isinstance(var, resource_variable_ops.BaseResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              f"is enabled. Got type: {type(var)}.")
        # `setdefault` returns the previously stored object when the shared
        # name is already present, which exposes a clash between two distinct
        # variable objects.
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              "Two different ResourceVariable objects with the same "
              f"shared_name '{var._shared_name}' were passed to the Saver. This"
              " likely means that they were created in different Graphs or "
              "isolated contexts, and may not be checkpointed together.")
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.BaseResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError(f"Variable to save is not a Variable: {var}")
        # For a read op, key the entry by the name of the op that produces
        # the read's first input instead of the read op itself.
        if var.op.type == "ReadVariableOp":
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError(f"At least two variables have the same name: {name}")
        names_to_saveables[name] = var

    # pylint: enable=protected-access
  return names_to_saveables


def _add_saveable(saveables, seen_ops, saveable):
  """Appends `saveable` to `saveables` and records its op in `seen_ops`.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed.  Used to
      check that each saveable is only saved once.
    saveable: The saveable.

  Raises:
    ValueError: If the saveable has already been processed.
  """
  saveable_op = saveable.op
  # Saveables without an op (`None`) are never treated as duplicates.
  already_processed = saveable_op is not None and saveable_op in seen_ops
  if already_processed:
    raise ValueError("The same saveable will be restored with two names: "
                     f"{saveable.name}")
  saveables.append(saveable)
  seen_ops.add(saveable_op)


def validate_and_slice_inputs(names_to_saveables):
  """Converts the Saver inputs into a flat list of SaveableObjects.

  Non-dict inputs are first turned into a name-to-saveable dict with
  `op_list_to_dict`. Entries are then visited in order of their names.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
       v is an operation to save or a BaseSaverBuilder.Saver.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  # Sort on the name only, so that ops themselves are never compared.
  entries_by_name = sorted(
      names_to_saveables.items(), key=lambda entry: entry[0])

  collected = []
  seen_ops = object_identity.ObjectIdentitySet()
  for name, op in entries_by_name:
    for saveable in saveable_objects_for_op(op, name):
      _add_saveable(collected, seen_ops, saveable)
  return collected


def trace_save_restore_functions(saveable_factory, obj):
  """Traces save and restore functions.

  Args:
    saveable_factory: Callable invoked as `saveable_factory(name=...)` that
      returns a SaveableObject or an iterable of SaveableObjects.
    obj: The object the factory belongs to; only passed on to
      `validate_saveables_for_saved_model`.

  Returns:
    A `(save, restore)` pair. If `saveable_factory` is a factory for a
    `RestoredSaveableObject`, the pair is the functions already stored in the
    factory's keywords. If validation leaves no saveables, `(None, None)`.
    Otherwise the concrete functions traced from the save and restore
    functions defined below.
  """
  if is_factory_for_restored_saveable_object(saveable_factory):
    return (saveable_factory.keywords["save_function"],
            saveable_factory.keywords["restore_function"])

  saveables = []  # Store the saveables in a data structure accessible to both
  # the save and restore functions.

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def save_fn(checkpoint_key):
    maybe_saveable = saveable_factory(name=checkpoint_key)
    if isinstance(maybe_saveable, saveable_object.SaveableObject):
      maybe_saveable = [maybe_saveable]
    # Slice assignment mutates the enclosing list in place, so the saveables
    # created here are visible outside of `save_fn`.
    saveables[:] = maybe_saveable

    # Return list of all SaveSpecs created by the factory.
    ret = []
    for saveable in saveables:
      for spec in saveable.specs:
        ret.append({"name": spec.name, "tensor": spec.tensor,
                    "slice_spec": spec.slice_spec})
    return ret

  concrete_save = save_fn.get_concrete_function()

  # The SaveableObjects are produced when `save_fn` is traced.
  saveables = validate_saveables_for_saved_model(saveables, obj)
  if not saveables:
    return None, None

  # Use the SaveSpecs to define the input signature of the restore function.
  # `tensor_structure` holds one list of spec names per saveable and is used
  # to regroup the flat restored tensors by saveable.
  restored_type_specs = []
  tensor_structure = []
  for saveable in saveables:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
      saveable_tensor_structure.append(spec.name)

  @def_function.function(input_signature=restored_type_specs)
  def restore_fn(*restored_tensors):
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    # NOTE: the loop variable below rebinds the `restored_tensors` parameter
    # to the per-saveable group of tensors.
    for saveable, restored_tensors in zip(saveables,
                                          structured_restored_tensors):
      saveable.restore(restored_tensors, restored_shapes=None)
    return 1  # Return dummy tensor

  concrete_restore = restore_fn.get_concrete_function()
  return concrete_save, concrete_restore


def validate_saveables_for_saved_model(saveables, obj):
  """Filters out SaveableObjects that cannot be used with SavedModel.

  Args:
    saveables: Iterable of SaveableObjects created for `obj`.
    obj: The object the saveables were created for.

  Returns:
    An empty list if `obj` is a `PythonState` (a warning is logged) or if any
    saveable is a `NoRestoreSaveable`; otherwise `saveables` unchanged.
  """
  if isinstance(obj, python_state.PythonState):
    logging.warn(
        f"Note that object {obj} stores python values into the checkpoint. "
        "These values will not be restored when loading the SavedModel "
        "into python.")
    return []

  # A single saveable that cannot be restored disqualifies the whole set.
  for saveable in saveables:
    if isinstance(saveable, trackable.NoRestoreSaveable):
      return []
  return saveables


class RestoredSaveableObject(saveable_object.SaveableObject):
  """SaveableObject backed by traced save and restore functions."""

  def __init__(self, save_function, restore_function, name):
    """Calls `save_function` to obtain the specs of this saveable.

    Args:
      save_function: Callable taking the checkpoint key tensor and returning
        an iterable of dicts with "tensor", "slice_spec" and "name" entries.
      restore_function: Callable taking one restored tensor per spec.
      name: Checkpoint key, either a TF type or a value that is converted to
        a constant.
    """
    self.save_function = save_function
    self.restore_function = restore_function

    name_tensor = name
    if not tensor_util.is_tf_type(name):
      # Create the constant inside `init_scope`.
      with ops.init_scope():
        name_tensor = constant_op.constant(name)

    specs = []
    for entry in save_function(name_tensor):
      specs.append(
          saveable_object.SaveSpec(
              entry["tensor"], entry["slice_spec"], entry["name"]))
    super().__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Passes one restored tensor per spec to the restore function."""
    del restored_shapes  # unused
    per_spec_tensors = [
        restored_tensors[index] for index, _ in enumerate(self.specs)]
    return self.restore_function(*per_spec_tensors)


def restored_saved_object_factory(save_function, restore_function):
  """Returns a partial that builds a `RestoredSaveableObject` from a name.

  The two functions are bound as keyword arguments, so they can later be read
  back from the partial's `keywords`.
  """
  bound_functions = {
      "save_function": save_function,
      "restore_function": restore_function,
  }
  return functools.partial(RestoredSaveableObject, **bound_functions)


def create_saveable_object(factory, name, call_with_mapped_captures):
  """Creates a SaveableObject while potentially in a different graph.

  When creating the frozen saver for SavedModel, the save and restore ops are
  placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
  save and restore, the function captures must be mapped to the new graph.

  Args:
    factory: Factory method for creating the SaveableObject.
    name: Checkpoint key of this SaveableObject.
    call_with_mapped_captures: Helper that calls a tf.function while remapping
      the captures.

  Returns:
    a SaveableObject.
  """
  # Remapping only applies when a helper is supplied and the factory builds a
  # RestoredSaveableObject; every other case calls the factory directly.
  needs_remapping = (
      call_with_mapped_captures is not None and
      is_factory_for_restored_saveable_object(factory))
  if not needs_remapping:
    return factory(name=name)

  concrete_save_fn = factory.keywords["save_function"]
  concrete_restore_fn = factory.keywords["restore_function"]

  def save_fn(name):
    return call_with_mapped_captures(concrete_save_fn, [name])

  def restore_fn(*restored_tensors):
    return call_with_mapped_captures(concrete_restore_fn, restored_tensors)

  # Keyword arguments given here override the ones bound in the partial.
  return factory(name=name, save_function=save_fn, restore_function=restore_fn)


def is_factory_for_restored_saveable_object(factory):
  """Returns whether `factory` is a partial of `RestoredSaveableObject`."""
  if not isinstance(factory, functools.partial):
    return False
  return factory.func is RestoredSaveableObject


@tf_export("__internal__.tracking.saveable_objects_from_trackable", v1=[])
def saveable_objects_from_trackable(obj):
  """Returns SaveableObject factory dict from a Trackable.

  `PythonState` objects get a single "py_state" factory, objects that override
  `_serialize_to_tensors` get a single `TrackableSaveable` factory, and all
  other objects fall back to `_gather_saveables_for_checkpoint`.
  """
  if isinstance(obj, python_state.PythonState):
    py_state_factory = functools.partial(
        _PythonStringStateSaveable,
        state_callback=obj.serialize,
        restore_callback=obj.deserialize)
    return {"py_state": py_state_factory}

  if not trackable_has_serialize_to_tensor(obj):
    return obj._gather_saveables_for_checkpoint()  # pylint: disable=protected-access

  def create_saveable(name=""):
    return TrackableSaveable(obj, name)

  return {trackable_utils.SERIALIZE_TO_TENSORS_NAME: create_saveable}


class TrackableSaveable(saveable_object.SaveableObject):
  """SaveableObject built from a `Trackable`'s tensor serialization methods."""

  def __init__(self, obj, name):
    """Creates one `SaveSpec` per tensor returned by `_serialize_to_tensors`.

    Args:
      obj: Trackable providing `_serialize_to_tensors` and
        `_restore_from_tensors`.
      name: Prefix for the spec names; each spec name is `name` followed by
        the escaped local tensor name.
    """
    self._trackable = obj
    serialized = obj._serialize_to_tensors()  # pylint: disable=protected-access
    # Local names are kept in the same order as the specs created below.
    self._local_names = list(serialized)
    self._prefix = saveable_compat.get_saveable_name(self._trackable) or ""
    specs = [
        saveable_object.SaveSpec(
            tensor, "", name + trackable_utils.escape_local_name(local_name))
        for local_name, tensor in serialized.items()
    ]
    super().__init__(obj, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Passes the restored tensors, keyed by local name, to the trackable."""
    del restored_shapes  # Unused.
    restored_tensor_dict = {
        local_name: restored_tensors[index]
        for index, local_name in enumerate(self._local_names)
    }

    def restore_from_tensors():
      self._trackable._restore_from_tensors(restored_tensor_dict)  # pylint: disable=protected-access
      # In graph mode, this wrapper function is converted into a tf.function,
      # and to ensure that _restore_from_tensors is executed, there must be at
      # least one returned tensor. `_restore_from_tensors` may return zero
      # tensors so create a dummy constant here.
      return constant_op.constant(1)

    if ops.executing_eagerly_outside_functions():
      return restore_from_tensors()
    return def_function.function(restore_from_tensors)()

  def get_proto_names_and_checkpoint_keys(self):
    """Returns `(prefix + local name, spec name)` pairs, one per spec."""
    pairs = []
    for local_name, spec in zip(self._local_names, self.specs):
      pairs.append((self._prefix + local_name, spec.name))
    return pairs


class _PythonStringStateSaveable(saveable_object.SaveableObject):
  """Saves a Python string, produced by a callback, in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state.
    """

    def _state_callback_wrapper():
      # Always run the user callback inside `init_scope`.
      with ops.init_scope():
        return state_callback()

    self._restore_callback = restore_callback
    self._state_callback = _state_callback_wrapper
    # Empty string constant on CPU; `feed_dict_additions` maps this tensor to
    # the current state.
    with ops.device("/cpu:0"):
      placeholder = constant_op.constant("", dtype=dtypes.string)
    self._save_string = placeholder
    string_spec = saveable_object.SaveSpec(
        placeholder, "", name, dtype=dtypes.string)
    super().__init__(placeholder, [string_spec], name)

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    feed = {}
    feed[self._save_string] = self._state_callback()
    return feed

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      current_state = self._state_callback()
      return constant_op.constant(current_state, dtype=dtypes.string)

    return trackable.NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")


def trackable_has_serialize_to_tensor(obj):
  """Returns whether `obj` overrides `Trackable._serialize_to_tensors`."""
  # pylint: disable=protected-access
  serialize_fn = obj._serialize_to_tensors
  # Unwrap bound methods so the underlying function is what gets compared.
  serialize_fn = getattr(serialize_fn, "__func__", serialize_fn)
  return trackable.Trackable._serialize_to_tensors != serialize_fn
  # pylint: enable=protected-access


class SaveableCompatibilityConverter(trackable.Trackable):
  """Converts object's `SaveableObjects` to functions used in TF2 checkpointing.

  A class that converts a Trackable object's `SaveableObjects` to save and
  restore functions with the same signatures as
  `Trackable._serialize_to_tensors` and `Trackable._restore_from_tensors`.
  This class also produces a method for filling the object proto.
  """

  __slots__ = ("_obj", "_cached_saveables")

  def __init__(self, obj):
    """Constructor.

    Args:
      obj: A Trackable object which implements the deprecated
        `_gather_saveables_for_checkpoint`.
    """
    self._obj = obj
    self._cached_saveables = None

    _ = self._saveables  # Generate cached saveables when converter is created.

  @property
  def _saveables(self):
    """Returns a list of SaveableObjects generated from the Trackable object.

    The list is built on first access and cached. Unless forced checkpoint
    conversion is enabled, building it also runs the legacy-decorator check.
    """
    if self._cached_saveables is not None:
      return self._cached_saveables

    self._cached_saveables = []
    # Factory name for every cached saveable, in the same order.
    saveable_names = []
    for name, saveable_factory in (
        saveable_objects_from_trackable(self._obj).items()):
      if callable(saveable_factory):
        maybe_saveable = create_saveable_object(
            saveable_factory, name, call_with_mapped_captures=None)
      else:
        maybe_saveable = saveable_factory
      if isinstance(maybe_saveable, saveable_object.SaveableObject):
        saveables = (maybe_saveable,)
      else:
        saveables = tuple(saveable_objects_for_op(op=maybe_saveable, name=name))
      self._cached_saveables.extend(saveables)
      saveable_names.extend([name] * len(saveables))

    if not saveable_compat.force_checkpoint_conversion_enabled():
      # Run an extra step to validate that the converter can be used without
      # changing the checkpoint metadata.
      self._maybe_apply_legacy_decorator(saveable_names)

    return self._cached_saveables

  def _maybe_apply_legacy_decorator(self, saveable_names):
    """Applies `legacy_saveable_name` when a saveable has differing spec names.

    Args:
      saveable_names: List with the factory name of every cached saveable.

    Raises:
      saveable_compat.CheckpointConversionError: If a saveable has specs with
        different names and more than one distinct factory name is present.
    """
    # Check the spec names. If there are multiple specs with different names
    # under the same saveable, then this indicates that a decorator must be
    # used to ensure checkpoint equality under the new checkpoint
    # implementation. See the docstring `legacy_saveable_name` for details.
    for saveable in self._cached_saveables:
      # Compare the *names* of the specs; the spec objects themselves are
      # distinct even when they share a name.
      spec_names = {spec.name for spec in saveable.specs}

      if len(spec_names) == 1:
        continue  # Decorator not needed.

      if len(set(saveable_names)) > 1:
        # An edge case not handled by the legacy decorator has been encountered.
        raise saveable_compat.CheckpointConversionError

      saveable_compat.legacy_saveable_name(saveable_names[0])(self)

  def _serialize_to_tensors(self):
    """Returns a dict of tensors to serialize.

    Specs without a slice spec map their name directly to the tensor. Specs
    with a slice spec are grouped in a nested dict, keyed first by spec name
    and then by slice spec.
    """
    tensor_dict = {}
    for saveable in self._saveables:
      for spec in saveable.specs:
        tensor = spec.tensor
        if spec.slice_spec:
          # Create the per-name dict on first use so that the first slice of
          # a tensor does not raise a KeyError.
          tensor_dict.setdefault(spec.name, {})[spec.slice_spec] = tensor
        else:
          tensor_dict[spec.name] = tensor
    return tensor_dict

  def _restore_from_tensors(self, restored_tensors):
    """Returns the restore ops defined in the Saveables.

    Args:
      restored_tensors: Dict keyed by spec name; values are tensors, or dicts
        keyed by slice spec for sliced specs.

    Returns:
      A dict mapping each saveable's name to the result of its `restore`.

    Raises:
      ValueError: If the keys of `restored_tensors` do not exactly match the
        spec names of the saveables.
    """
    # Map restored tensors to the corresponding SaveableObjects, then call
    # restore. There must be an exact match between restored tensors and the
    # expected attributes.
    expected_keys = []
    for saveable in self._saveables:
      expected_keys.extend(spec.name for spec in saveable.specs)
    if set(expected_keys) != restored_tensors.keys():
      raise ValueError(f"Could not restore object {self._obj} because not all "
                       "expected tensors were in the checkpoint."
                       f"\n\tExpected: {expected_keys}"
                       f"\n\tGot: {list(restored_tensors.keys())}")

    restore_ops = {}
    for saveable in self._saveables:
      saveable_restored_tensors = []
      for spec in saveable.specs:
        if spec.slice_spec:
          saveable_restored_tensors.append(
              restored_tensors[spec.name][spec.slice_spec])
        else:
          saveable_restored_tensors.append(restored_tensors[spec.name])

      restore_ops[saveable.name] = saveable.restore(
          saveable_restored_tensors, restored_shapes=None)
    return restore_ops
