# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python dataset sparse tensor utility functions."""
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import sparse_ops


def any_sparse(classes):
  """Reports whether any leaf of `classes` is the sparse tensor class.

  Args:
    classes: a structure of objects that identify the dataset item classes

  Returns:
    `True` as soon as one flattened element of `classes` is (by identity)
    `sparse_tensor.SparseTensor`; `False` if none is.
  """
  for item_class in nest.flatten(classes):
    # Identity comparison: entries are class objects, not tensor instances.
    if item_class is sparse_tensor.SparseTensor:
      return True
  return False


def as_dense_shapes(shapes, classes):
  """Replaces the shapes of sparse components with an unknown shape.

  Args:
    shapes: a structure of shapes to convert.
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure with the same nesting as `shapes`. Each position whose class
    is `tf.sparse.SparseTensor` holds `tensor_shape.unknown_shape()`; every
    other position holds the corresponding entry of `shapes` unchanged.
  """
  flat_shapes = nest.flatten(shapes)
  flat_classes = nest.flatten(classes)

  dense_shapes = []
  for current_shape, item_class in zip(flat_shapes, flat_classes):
    if item_class is sparse_tensor.SparseTensor:
      dense_shapes.append(tensor_shape.unknown_shape())
    else:
      dense_shapes.append(current_shape)

  return nest.pack_sequence_as(shapes, dense_shapes)


def as_dense_types(types, classes):
  """Replaces the types of sparse components with `dtypes.variant`.

  Args:
    types: a structure of types to convert.
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure with the same nesting as `types`. Each position whose class
    is `tf.sparse.SparseTensor` holds `dtypes.variant`; every other position
    holds the corresponding entry of `types` unchanged.
  """
  flat_types = nest.flatten(types)
  flat_classes = nest.flatten(classes)

  dense_types = []
  for current_type, item_class in zip(flat_types, flat_classes):
    if item_class is sparse_tensor.SparseTensor:
      dense_types.append(dtypes.variant)
    else:
      dense_types.append(current_type)

  return nest.pack_sequence_as(types, dense_types)


def deserialize_sparse_tensors(tensors, types, shapes, classes):
  """Restores sparse tensors from their serialized form.

  Args:
    tensors: a structure of tensors to deserialize.
    types: a structure that holds information about types of `tensors`
    shapes: a structure that holds information about shapes of `tensors`
    classes: a structure of objects that identify the dataset item classes

  Returns:
    a structure, nested like `types`, in which every component whose class is
    `tf.sparse.SparseTensor` has been passed through
    `sparse_ops.deserialize_sparse` and every other component is returned
    as-is.
  """
  flat_tensors = nest.flatten(tensors)
  flat_types = nest.flatten(types)
  flat_shapes = nest.flatten(shapes)
  flat_classes = nest.flatten(classes)

  restored = []
  for component, component_type, component_shape, item_class in zip(
      flat_tensors, flat_types, flat_shapes, flat_classes):
    if item_class is sparse_tensor.SparseTensor:
      # `ndims` is read only for sparse components, as the rank hint for the
      # deserialization op.
      component = sparse_ops.deserialize_sparse(
          component, dtype=component_type, rank=component_shape.ndims)
    restored.append(component)

  # The output structure is taken from `types`, not from `tensors`.
  return nest.pack_sequence_as(types, restored)


def get_classes(tensors):
  """Maps every component of `tensors` to the class that identifies it.

  Args:
    tensors: the tensor structure to get classes for.

  Returns:
    a structure with the same nesting as `tensors`, holding
    `tf.sparse.SparseTensor` wherever the component is a `SparseTensor`
    instance and `tf.Tensor` everywhere else.
  """
  component_classes = []
  for component in nest.flatten(tensors):
    if isinstance(component, sparse_tensor.SparseTensor):
      component_classes.append(sparse_tensor.SparseTensor)
    else:
      # Anything that is not a SparseTensor instance is labelled as a Tensor.
      component_classes.append(ops.Tensor)
  return nest.pack_sequence_as(tensors, component_classes)


def serialize_many_sparse_tensors(tensors):
  """Serializes each sparse component with `serialize_many_sparse`.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    a structure with the same nesting as `tensors`, in which every component
    accepted by `sparse_tensor.is_sparse` has been replaced by the result of
    `sparse_ops.serialize_many_sparse` (with `out_type=dtypes.variant`) and
    every other component is returned unchanged.
  """
  serialized = []
  for component in nest.flatten(tensors):
    if sparse_tensor.is_sparse(component):
      component = sparse_ops.serialize_many_sparse(
          component, out_type=dtypes.variant)
    serialized.append(component)
  return nest.pack_sequence_as(tensors, serialized)


def serialize_sparse_tensors(tensors):
  """Serializes each sparse component with `serialize_sparse`.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    a structure with the same nesting as `tensors`, in which every
    `SparseTensor` instance has been replaced by the result of
    `sparse_ops.serialize_sparse` (with `out_type=dtypes.variant`) and every
    other component is returned unchanged.
  """
  serialized = []
  for component in nest.flatten(tensors):
    if isinstance(component, sparse_tensor.SparseTensor):
      component = sparse_ops.serialize_sparse(
          component, out_type=dtypes.variant)
    serialized.append(component)
  return nest.pack_sequence_as(tensors, serialized)
