# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef method name utility functions.

Utility functions for manipulating signature_def.method_names.
"""

from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader_impl as loader
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


# TODO(jdchung): Consider integrating this into the saved_model_cli so that
# users could do this from the command line directly.
@tf_export(v1=["saved_model.signature_def_utils.MethodNameUpdater"])
class MethodNameUpdater(object):
  """Updates the method name(s) of the SavedModel stored in the given path.

  The `MethodNameUpdater` class provides the functionality to update the method
  name field in the signature_defs of the given SavedModel. For example, it
  can be used to replace the `predict` `method_name` with `regress`.

  Typical usages of the `MethodNameUpdater`
  ```python
  ...
  updater = tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater(
      export_dir)
  # Update all signature_defs with key "foo" in all meta graph defs.
  updater.replace_method_name(signature_key="foo", method_name="regress")
  # Update a single signature_def with key "bar" in the meta graph def with
  # tags ["serve"]
  updater.replace_method_name(signature_key="bar", method_name="classify",
                              tags="serve")
  updater.save(new_export_dir)
  ```

  Note: This class is only available through the v1 compatibility library as
  `tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater`.
  """

  def __init__(self, export_dir):
    """Creates a MethodNameUpdater object.

    Args:
      export_dir: Directory containing the SavedModel files.

    Raises:
      IOError: If the saved model file does not exist, or cannot be successfully
      parsed.
    """
    self._export_dir = export_dir
    self._saved_model = loader.parse_saved_model(export_dir)

  def replace_method_name(self, signature_key, method_name, tags=None):
    """Replaces the method_name in the specified signature_def.

    If `tags` is None, the signature_def keyed by `signature_key` is updated in
    every `MetaGraph` of the SavedModel (i.e. when multiple `MetaGraph`s have a
    signature_def with the same key, all of them are replaced). If `tags` is
    not None, only the `MetaGraph`(s) whose tag set equals `tags` (order
    insensitive) are updated.

    The update is all-or-nothing: every targeted `MetaGraph` is checked before
    any signature_def is modified, so when a `ValueError` is raised the
    in-memory SavedModel is left unchanged.

    Args:
      signature_key: Key of the signature_def to be updated.
      method_name: new method_name to replace the existing one.
      tags: A single tag, or a list/tuple/set of tags, identifying the
          `MetaGraph` to update. If None, all meta graphs will be updated.
    Raises:
      ValueError: if signature_key or method_name are not defined or
          if no metagraphs were found with the associated tags or
          if a targeted meta graph has no signature_def that matches
          signature_key.
    """
    if not signature_key:
      raise ValueError("`signature_key` must be defined.")
    if not method_name:
      raise ValueError("`method_name` must be defined.")

    wanted_tags = None
    if tags is not None:
      # Any common collection is a group of tags; anything else (typically a
      # single tag string) is one tag. Checking only `list` here would wrap a
      # tuple like ("serve", "gpu") as a single element that never matches.
      if not isinstance(tags, (list, tuple, set, frozenset)):
        tags = [tags]
      tags = list(tags)
      wanted_tags = set(tags)

    # Pass 1: select and validate targets without mutating anything, so an
    # error part-way through cannot leave some signature_defs already updated.
    targets = []
    for meta_graph_def in self._saved_model.meta_graphs:
      if (wanted_tags is not None and
          wanted_tags != set(meta_graph_def.meta_info_def.tags)):
        continue
      if signature_key not in meta_graph_def.signature_def:
        raise ValueError(
            f"MetaGraphDef associated with tags {tags} "
            f"does not have a signature_def with key: '{signature_key}'. "
            "This means either you specified the wrong signature key or "
            "forgot to put the signature_def with the corresponding key in "
            "your SavedModel.")
      targets.append(meta_graph_def)

    if not targets:
      raise ValueError(
          f"MetaGraphDef associated with tags {tags} could not be found in "
          "SavedModel. This means either you specified invalid tags or your "
          "SavedModel does not have a MetaGraphDef with the specified tags.")

    # Pass 2: every target is known to contain `signature_key`; apply update.
    for meta_graph_def in targets:
      meta_graph_def.signature_def[signature_key].method_name = method_name

  @staticmethod
  def _join_path(directory, filename):
    """Joins `directory` and `filename` after converting both to bytes."""
    return file_io.join(compat.as_bytes(directory), compat.as_bytes(filename))

  def save(self, new_export_dir=None):
    """Saves the updated `SavedModel`.

    Only the SavedModel protocol buffer file is written. Other contents of the
    original export directory (e.g. variables or assets) are not copied to
    `new_export_dir`.

    The output is written as text (`constants.SAVED_MODEL_FILENAME_PBTXT`) if
    the input SavedModel was a text proto, and as a deterministically
    serialized binary proto (`constants.SAVED_MODEL_FILENAME_PB`) otherwise.

    Args:
      new_export_dir: Path where the updated `SavedModel` will be saved. If
          None, the input `SavedModel` will be overwritten with the updates.

    Raises:
      errors.OpError: If there are errors during the file save operation.
    """
    has_binary_proto = file_io.file_exists(
        self._join_path(self._export_dir, constants.SAVED_MODEL_FILENAME_PB))
    has_text_proto = file_io.file_exists(
        self._join_path(self._export_dir, constants.SAVED_MODEL_FILENAME_PBTXT))
    # NOTE(review): assumes loader.parse_saved_model reads the binary proto in
    # preference to the text proto when both exist. Write back in the format
    # that was actually parsed; otherwise an in-place save would leave a
    # stale binary proto that shadows the updated text proto.
    is_input_text_proto = has_text_proto and not has_binary_proto

    if not new_export_dir:
      new_export_dir = self._export_dir

    if is_input_text_proto:
      path = self._join_path(new_export_dir,
                             constants.SAVED_MODEL_FILENAME_PBTXT)
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = self._join_path(new_export_dir, constants.SAVED_MODEL_FILENAME_PB)
      file_io.write_string_to_file(
          path, self._saved_model.SerializeToString(deterministic=True))
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
