# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for save and loading a dataset."""

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "NONE"
DATASET_SPEC_FILENAME = "dataset_spec.pb"
# TODO(b/176933539): Use the regular import.
# TODO(b/238903802): Use TypeSpec serialization methods directly.
nested_structure_coder = lazy_loader.LazyLoader(
    "nested_structure_coder", globals(),
    "tensorflow.python.saved_model.nested_structure_coder")


@tf_export("data.experimental.save", v1=[])
@deprecation.deprecated(None, "Use `tf.data.Dataset.save(...)` instead.")
def save(dataset,
         path,
         compression=None,
         shard_func=None,
         checkpoint_args=None):
  """Saves the content of the given dataset.

  Example usage:

  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  The saved dataset is saved in multiple file "shards". By default, the dataset
  output is divided to shards in a round-robin fashion but custom sharding can
  be specified via the `shard_func` function. For example, you can save the
  dataset to using a single shard as follows:

  ```python
  dataset = make_dataset()
  def custom_shard_func(element):
    return np.int64(0)
  dataset = tf.data.experimental.save(
      path="/path/to/data", ..., shard_func=custom_shard_func)
  ```

  To enable checkpointing, pass in `checkpoint_args` to the `save` method
  as follows:

  ```python
  dataset = tf.data.Dataset.range(100)
  save_dir = "..."
  checkpoint_prefix = "..."
  step_counter = tf.Variable(0, trainable=False)
  checkpoint_args = {
    "checkpoint_interval": 50,
    "step_counter": step_counter,
    "directory": checkpoint_prefix,
    "max_to_keep": 20,
  }
  dataset.save(dataset, save_dir, checkpoint_args=checkpoint_args)
  ```

  NOTE: The directory layout and file format used for saving the dataset is
  considered an implementation detail and may change. For this reason, datasets
  saved through `tf.data.experimental.save` should only be consumed through
  `tf.data.experimental.load`, which is guaranteed to be backwards compatible.

  Args:
    dataset: The dataset to save.
    path: Required. A directory to use for saving the dataset.
    compression: Optional. The algorithm to use to compress data when writing
      it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    shard_func: Optional. A function to control the mapping of dataset elements
      to file shards. The function is expected to map elements of the input
      dataset to int64 shard IDs. If present, the function will be traced and
      executed as graph computation.
    checkpoint_args: Optional args for checkpointing which will be passed into
      the `tf.train.CheckpointManager`. If `checkpoint_args` are not specified,
      then checkpointing will not be performed. The `save()` implementation
      creates a `tf.train.Checkpoint` object internally, so users should not
      set the `checkpoint` argument in `checkpoint_args`.
  Raises:
    ValueError if `checkpoint` is passed into `checkpoint_args`.
  """
  dataset.save(path, compression, shard_func, checkpoint_args)


@tf_export("data.experimental.load", v1=[])
@deprecation.deprecated(None, "Use `tf.data.Dataset.load(...)` instead.")
def load(path, element_spec=None, compression=None, reader_func=None):
  """Loads a previously saved dataset.

  Example usage:

  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)


  If the default option of sharding the saved dataset was used, the element
  order of the saved dataset will be preserved when loading it.

  The `reader_func` argument can be used to specify a custom order in which
  elements should be loaded from the individual shards. The `reader_func` is
  expected to take a single argument -- a dataset of datasets, each containing
  elements of one of the shards -- and return a dataset of elements. For
  example, the order of shards can be shuffled when loading them as follows:

  ```python
  def custom_reader_func(datasets):
    datasets = datasets.shuffle(NUM_SHARDS)
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = tf.data.experimental.load(
      path="/path/to/data", ..., reader_func=custom_reader_func)
  ```

  Args:
    path: Required. A path pointing to a previously saved dataset.
    element_spec: Optional. A nested structure of `tf.TypeSpec` objects matching
      the structure of an element of the saved dataset and specifying the type
      of individual element components. If not provided, the nested structure of
      `tf.TypeSpec` saved with the saved dataset is used. Note that this
      argument is required in graph mode.
    compression: Optional. The algorithm to use to decompress the data when
      reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    reader_func: Optional. A function to control how to read data from shards.
      If present, the function will be traced and executed as graph computation.

  Returns:
    A `tf.data.Dataset` instance.

  Raises:
    FileNotFoundError: If `element_spec` is not specified and the saved nested
      structure of `tf.TypeSpec` can not be located with the saved dataset.
    ValueError: If `element_spec` is not specified and the method is executed
      in graph mode.
  """
  return dataset_ops.Dataset.load(path, element_spec, compression, reader_func)
