# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Options for saving SavedModels."""

import enum
import six

from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


@tf_export("saved_model.experimental.VariablePolicy")
class VariablePolicy(enum.Enum):
  """Enum describing how variables are handled when a SavedModel is written.

  NONE
    No policy is applied: a distributed variable is saved as a single
    variable, and no device is attached to it.

  SAVE_VARIABLE_DEVICES
    The device assignment of each variable is saved along with the variable.
    This helps when devices should be hardcoded in the saved model, but it
    also makes the model non-portable if soft device placement is disabled
    (see `tf.config.set_soft_device_placement` for details). The policy is
    currently not fully supported by `saved_model.load`; it is mainly meant
    for cases where the saved model is read at a lower API level. In the
    example below, the graph saved by the call to `saved_model.save` has the
    variable devices correctly specified:
    ```python
    exported = tf.train.Checkpoint()
    with tf.device('/GPU:0'):
      exported.x_gpu = tf.Variable(1.0)
    with tf.device('/CPU:0'):
      exported.x_cpu = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options=tf.saved_model.SaveOptions(
            experimental_variable_policy=
              tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
    ```
    Under this policy a distributed variable is still saved as one variable.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables are saved together with information about their
    components, so that they can be restored on load. The saved graph also
    contains references to those variables. This is useful when the model is
    to be used for training in an environment where the original distribution
    strategy is not available.
  """

  # Member order is the iteration order that `from_obj` searches in.
  NONE = None
  SAVE_VARIABLE_DEVICES = "save_variable_devices"
  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"

  def _save_variable_devices(self):
    """Returns True when variable devices should be saved.

    Every policy except `NONE` saves the device assignments.
    """
    return self is not VariablePolicy.NONE

  def _expand_distributed_variables(self):
    """Returns True when distributed variables should be expanded."""
    return self is VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

  @staticmethod
  def from_obj(obj):
    """Converts `obj` into a `VariablePolicy` member.

    Args:
      obj: `None`, a `VariablePolicy` member, or an object whose `str()` form
        equals one of the member value strings when lowercased.

    Returns:
      `VariablePolicy.NONE` for `None`, `obj` itself when it already is a
      member, otherwise the member whose value matches the lowercased string.

    Raises:
      ValueError: If no member value matches `obj`.
    """
    if isinstance(obj, VariablePolicy):
      return obj
    if obj is None:
      return VariablePolicy.NONE

    # The lookup is case-insensitive. `NONE` has the value `None`, so it is
    # never selected by a string (not even by "none").
    lowered = str(obj).lower()
    found = next(
        (member for member in VariablePolicy if member.value == lowered),
        None)
    if found is None:
      raise ValueError(f"Received invalid VariablePolicy value: {obj}.")
    return found


@tf_export("saved_model.SaveOptions")
class SaveOptions:
  """Options that control how a SavedModel is written.

  An instance may be passed as the `options` argument of the functions that
  save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
  """

  # Attributes are declared in __slots__ for improved memory and performance.
  __slots__ = (
      "namespace_whitelist",
      "save_debug_info",
      "function_aliases",
      "experimental_io_device",
      "experimental_variable_policy",
      "experimental_custom_gradients",
  )

  def __init__(self,
               namespace_whitelist=None,
               save_debug_info=False,
               function_aliases=None,
               experimental_io_device=None,
               experimental_variable_policy=None,
               experimental_custom_gradients=True):
    """Creates an object that stores options for SavedModel saving.

    Args:
      namespace_whitelist: List of strings naming the op namespaces to
        whitelist when saving a model. When an object that uses namespaced ops
        is saved, all of its namespaces must be added to the whitelist
        explicitly. The namespaced ops must be registered into the framework
        when the SavedModel is loaded. If no whitelist is provided, all
        namespaced ops are allowed.
      save_debug_info: Boolean indicating whether debug information is saved.
        If True, a debug/saved_model_debug_info.pb file is written with the
        contents of a GraphDebugInfo binary protocol buffer containing stack
        trace information for all ops and functions that are saved.
      function_aliases: Python dict. Mapping from string to object returned by
        @tf.function. A single tf.function can generate many ConcreteFunctions.
        If a downstream tool wants to refer to all concrete functions generated
        by a single tf.function you can use the `function_aliases` argument to
        store a map from the alias name to all concrete function names.
        E.g.

        >>> class Adder(tf.Module):
        ...   @tf.function
        ...   def double(self, x):
        ...     return x + x

        >>> model = Adder()
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.float32, name="float_input"))
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))

        >>> options = tf.saved_model.SaveOptions(
        ...   function_aliases={'double': model.double})
        >>> tf.saved_model.save(model, '/tmp/adder', options=options)

      experimental_io_device: string. Applies in a distributed setting.
        Tensorflow device to use to access the filesystem. With `None` (the
        default), the filesystem is accessed for each variable from the CPU:0
        device of the host where that variable is assigned. If a device is
        specified, the filesystem is accessed from that device for all
        variables instead.

        This is useful, for example, when saving to a local directory such as
        "/tmp" while running in a distributed setting. In that case pass a
        device for the host where the "/tmp" directory is accessible.
      experimental_variable_policy: The policy to apply to variables when
        saving. Either a `saved_model.experimental.VariablePolicy` enum
        instance or one of its value strings (case is not important). See the
        documentation of that enum for details. `None` selects the default
        policy.
      experimental_custom_gradients: Boolean. When True, the traced gradient
        functions of the functions decorated by `tf.custom_gradient` are
        saved. Defaults to `True`.
    """
    # The whitelist is validated before anything else is stored.
    self.namespace_whitelist = _validate_namespace_whitelist(
        namespace_whitelist)

    # Values that are stored without validation.
    self.save_debug_info = save_debug_info
    # Any falsy value (`None`, an empty dict) is replaced by a fresh dict.
    self.function_aliases = function_aliases or {}
    self.experimental_io_device = experimental_io_device
    self.experimental_custom_gradients = experimental_custom_gradients

    # Accepts a VariablePolicy member, a value string or None.
    self.experimental_variable_policy = VariablePolicy.from_obj(
        experimental_variable_policy)


def _validate_namespace_whitelist(namespace_whitelist):
  """Validates namespace whitelist argument."""
  if namespace_whitelist is None:
    return None
  if not isinstance(namespace_whitelist, list):
    raise TypeError("`namespace_whitelist` must be a list of strings. Got: "
                    f"{namespace_whitelist} with type "
                    f"{type(namespace_whitelist)}.")

  processed = []
  for namespace in namespace_whitelist:
    if not isinstance(namespace, six.string_types):
      raise ValueError("Whitelisted namespace must be a string. Got: "
                       f"{namespace} of type {type(namespace)}.")
    processed.append(compat.as_str(namespace))
  return processed
