# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras Input Tensor used to track functional API Topology."""

import tensorflow.compat.v2 as tf

from keras.utils import object_identity

# isort: off
from tensorflow.python.data.util import structure


# Tensorflow tensors have a maximum rank of 254
# (See `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h )
# So we do not try to infer values for int32 tensors larger than this,
# As they cannot represent shapes.
_MAX_TENSOR_RANK = 254


class KerasTensor:
    """A representation of a Keras in/output during Functional API construction.

    `KerasTensor`s are tensor-like objects that represent the symbolic inputs
    and outputs of Keras layers during Functional model construction. They are
    comprised of the `tf.TypeSpec` of the (Composite)Tensor that will be
    consumed/produced in the corresponding location of the Functional model.

    KerasTensors are intended as a private API, so users should never need to
    directly instantiate `KerasTensor`s.

    **Building Functional Models with KerasTensors**
    `tf.keras.Input` produces `KerasTensor`s that represent the symbolic inputs
    to your model.

    Passing a `KerasTensor` to a `tf.keras.Layer` `__call__` lets the layer know
    that you are building a Functional model. The layer __call__ will
    infer the output signature and return `KerasTensor`s with `tf.TypeSpec`s
    corresponding to the symbolic outputs of that layer call. These output
    `KerasTensor`s will have all of the internal KerasHistory metadata attached
    to them that Keras needs to construct a Functional Model.

    Currently, layers infer the output signature by:
      * creating a scratch `FuncGraph`
      * making placeholders in the scratch graph that match the input typespecs
      * Calling `layer.call` on these placeholders
      * extracting the signatures of the outputs before clearing the scratch
        graph

    (Note: names assigned to KerasTensors by this process are not guaranteed to
    be unique, and are subject to implementation details).

    `tf.nest` methods are used to insure all of the inputs/output data
    structures get maintained, with elements swapped between KerasTensors and
    placeholders.

    In rare cases (such as when directly manipulating shapes using Keras
    layers), the layer may be able to partially infer the value of the output in
    addition to just inferring the signature.
    When this happens, the returned KerasTensor will also contain the inferred
    value information. Follow-on layers can use this information.
    during their own output signature inference.
    E.g. if one layer produces a symbolic `KerasTensor` that the next layer uses
    as the shape of its outputs, partially knowing the value helps infer the
    output shape.

    **Automatically converting TF APIs to layers**:
    If you passing a `KerasTensor` to a TF API that supports dispatching,
    Keras will automatically turn that API call into a lambda
    layer in the Functional model, and return KerasTensors representing the
    symbolic outputs.

    Most TF APIs that take only tensors as input and produce output tensors
    will support dispatching.

    Calling a `tf.function` does not support dispatching, so you cannot pass
    `KerasTensor`s as inputs to a `tf.function`.

    Higher-order APIs that take methods which produce tensors (e.g. `tf.while`,
    `tf.map_fn`, `tf.cond`) also do not currently support dispatching. So, you
    cannot directly pass KerasTensors as inputs to these APIs either. If you
    want to use these APIs inside of a Functional model, you must put them
    inside of a custom layer.

    Args:
      type_spec: The `tf.TypeSpec` for the symbolic input created by
        `tf.keras.Input`, or symbolically inferred for the output
        during a symbolic layer `__call__`.
      inferred_value: (Optional) a non-symbolic static value, possibly partially
        specified, that could be symbolically inferred for the outputs during
        a symbolic layer `__call__`. This will generally only happen when
        grabbing and manipulating `tf.int32` shapes directly as tensors.
        Statically inferring values in this way and storing them in the
        KerasTensor allows follow-on layers to infer output signatures
        more effectively. (e.g. when using a symbolic shape tensor to later
        construct a tensor with that shape).
      name: (optional) string name for this KerasTensor. Names automatically
        generated by symbolic layer `__call__`s are not guaranteed to be unique,
        and are subject to implementation details.
    """

    def __init__(self, type_spec, inferred_value=None, name=None):
        """Constructs a KerasTensor."""
        if not isinstance(type_spec, tf.TypeSpec):
            raise ValueError(
                "KerasTensors must be constructed with a `tf.TypeSpec`."
            )

        self._type_spec = type_spec
        self._inferred_value = inferred_value
        self._name = name

        if not isinstance(type_spec, structure.NoneTensorSpec):
            if not hasattr(type_spec, "shape"):
                raise ValueError(
                    "KerasTensor only supports TypeSpecs that have a shape "
                    f"field; got {type(type_spec).__qualname__}, "
                    "which does not have a shape."
                )
            if not isinstance(type_spec.shape, tf.TensorShape):
                raise TypeError(
                    "KerasTensor requires that wrapped TypeSpec's shape is a "
                    f"TensorShape; got TypeSpec {type(type_spec).__qualname__}"
                    ", whose shape field has unexpected type "
                    f"{type(type_spec.dtype).__qualname__}."
                )

    @property
    def type_spec(self):
        """Returns the `tf.TypeSpec` symbolically inferred for Keras output."""
        return self._type_spec

    @property
    def shape(self):
        """Returns the `TensorShape` symbolically inferred for Keras output."""
        return self._type_spec.shape

    @classmethod
    def from_tensor(cls, tensor):
        """Convert a traced (composite)tensor to a representative
        KerasTensor."""
        if isinstance(tensor, tf.Tensor):
            name = getattr(tensor, "name", None)
            type_spec = tf.type_spec_from_value(tensor)
            inferred_value = None
            if (
                type_spec.dtype == tf.int32
                and type_spec.shape.rank is not None
                and type_spec.shape.rank < 2
            ):
                # If this tensor might be representing shape information,
                # (dtype=int32, rank of 0 or 1, not too large to represent a
                # shape) we attempt to capture any value information
                # tensorflow's shape handling can extract from the current
                # scratch graph.
                #
                # Even though keras layers each trace in their own scratch
                # graph, this shape value info extraction allows us to capture a
                # sizable and useful subset of the C++ shape value inference TF
                # can do if all tf ops appear in the same graph when using shape
                # ops.
                #
                # Examples of things this cannot infer concrete dimensions for
                # that the full single-graph C++ shape inference sometimes can
                # are:
                # * cases where the shape tensor is cast out of int32 before
                #   being manipulated w/ floating point numbers then converted
                #   back
                # * cases where int32 tensors w/ rank >= 2 are manipulated
                #   before being used as a shape tensor
                # * cases where int32 tensors too large to represent shapes are
                #   manipulated to a smaller size before being used as a shape
                #   tensor
                inferred_value = tf.ones(shape=tensor).shape
                if inferred_value.dims:
                    inferred_value = inferred_value.as_list()
                    if len(inferred_value) > _MAX_TENSOR_RANK:
                        inferred_value = None
                else:
                    inferred_value = None

            return KerasTensor(
                type_spec, inferred_value=inferred_value, name=name
            )
        else:
            # Fallback to the generic arbitrary-typespec KerasTensor
            name = getattr(tensor, "name", None)
            type_spec = tf.type_spec_from_value(tensor)
            return cls(type_spec, name=name)

    @classmethod
    def from_type_spec(cls, type_spec, name=None):
        return cls(type_spec=type_spec, name=name)

    def _to_placeholder(self):
        """Convert this KerasTensor to a placeholder in a graph."""
        # If there is an inferred value for this tensor, inject the inferred
        # value
        if self._inferred_value is not None:
            # If we suspect this KerasTensor might be representing a shape
            # tensor, and we were able to extract value information with
            # TensorFlow's shape handling when making the KerasTensor, we
            # construct the placeholder by re-injecting the inferred value
            # information into the graph. We do this injection through the shape
            # of a placeholder, because that allows us to specify
            # partially-unspecified shape values.
            #
            # See the comment on value extraction inside `from_tensor` for more
            # info.
            inferred_value = tf.shape(
                tf.compat.v1.placeholder(
                    shape=self._inferred_value, dtype=tf.int32
                )
            )
            if self.type_spec.shape.rank == 0:
                # `tf.shape` always returns a rank-1, we may need to turn it
                # back to a scalar.
                inferred_value = inferred_value[0]
            return inferred_value

        # Use the generic conversion from typespec to a placeholder.
        def component_to_placeholder(component):
            return tf.compat.v1.placeholder(component.dtype, component.shape)

        return tf.nest.map_structure(
            component_to_placeholder, self.type_spec, expand_composites=True
        )

    def get_shape(self):
        return self.shape

    def __len__(self):
        raise TypeError(
            "Keras symbolic inputs/outputs do not "
            "implement `__len__`. You may be "
            "trying to pass Keras symbolic inputs/outputs "
            "to a TF API that does not register dispatching, "
            "preventing Keras from automatically "
            "converting the API call to a lambda layer "
            "in the Functional Model. This error will also get raised "
            "if you try asserting a symbolic input/output directly."
        )

    @property
    def op(self):
        raise TypeError(
            "Keras symbolic inputs/outputs do not "
            "implement `op`. You may be "
            "trying to pass Keras symbolic inputs/outputs "
            "to a TF API that does not register dispatching, "
            "preventing Keras from automatically "
            "converting the API call to a lambda layer "
            "in the Functional Model."
        )

    def __hash__(self):
        raise TypeError(
            f"Tensors are unhashable (this tensor: {self}). "
            "Instead, use tensor.ref() as the key."
        )

    # Note: This enables the KerasTensor's overloaded "right" binary
    # operators to run when the left operand is an ndarray, because it
    # accords the Tensor class higher priority than an ndarray, or a
    # numpy matrix.
    # In the future explore changing this to using numpy's __numpy_ufunc__
    # mechanism, which allows more control over how Tensors interact
    # with ndarrays.
    __array_priority__ = 100

    def __array__(self, dtype=None):
        raise TypeError(
            f"You are passing {self}, an intermediate Keras symbolic "
            "input/output, to a TF API that does not allow registering custom "
            "dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, "
            "or `tf.map_fn`. Keras Functional model construction only supports "
            "TF API calls that *do* support dispatching, such as `tf.math.add` "
            "or `tf.reshape`. "
            "Other APIs cannot be called directly on symbolic Keras"
            "inputs/outputs. You can work around "
            "this limitation by putting the operation in a custom Keras layer "
            "`call` and calling that layer "
            "on this symbolic input/output."
        )

    @property
    def is_tensor_like(self):
        return True

    def set_shape(self, shape):
        """Updates the shape of this KerasTensor. Mimics
        `tf.Tensor.set_shape()`."""
        if not isinstance(shape, tf.TensorShape):
            shape = tf.TensorShape(shape)
        if not self.shape.is_compatible_with(shape):
            raise ValueError(
                f"Keras symbolic input/output's shape {self.shape} is not "
                f"compatible with supplied shape {shape}."
            )
        else:
            shape = self.shape.merge_with(shape)
            self._type_spec = type_spec_with_shape(self._type_spec, shape)

    def __str__(self):
        symbolic_description = ""
        inferred_value_string = ""
        name_string = ""

        if hasattr(self, "_keras_history"):
            layer = self._keras_history.layer
            symbolic_description = ", description=\"created by layer '%s'\"" % (
                layer.name,
            )
        if self._inferred_value is not None:
            inferred_value_string = ", inferred_value=%s" % self._inferred_value
        if self.name is not None:
            name_string = ", name='%s'" % self._name
        return "KerasTensor(type_spec=%s%s%s%s)" % (
            self.type_spec,
            inferred_value_string,
            name_string,
            symbolic_description,
        )

    def __repr__(self):
        symbolic_description = ""
        inferred_value_string = ""
        if isinstance(self.type_spec, tf.TensorSpec):
            type_spec_string = "shape=%s dtype=%s" % (
                self.shape,
                self.dtype.name,
            )
        else:
            type_spec_string = "type_spec=%s" % self.type_spec

        if hasattr(self, "_keras_history"):
            layer = self._keras_history.layer
            symbolic_description = " (created by layer '%s')" % (layer.name,)
        if self._inferred_value is not None:
            inferred_value_string = " inferred_value=%s" % self._inferred_value
        return "<KerasTensor: %s%s%s>" % (
            type_spec_string,
            inferred_value_string,
            symbolic_description,
        )

    @property
    def dtype(self):
        """Returns the `dtype` symbolically inferred for this Keras output."""
        type_spec = self._type_spec
        if not hasattr(type_spec, "dtype"):
            raise AttributeError(
                f"KerasTensor wraps TypeSpec {type(type_spec).__qualname__}, "
                "which does not have a dtype."
            )
        if not isinstance(type_spec.dtype, tf.DType):
            raise TypeError(
                "KerasTensor requires that wrapped TypeSpec's dtype is a "
                f"DType; got TypeSpec {type(type_spec).__qualname__}, whose "
                "dtype field has unexpected type "
                f"{type(type_spec.dtype).__qualname__}."
            )
        return type_spec.dtype

    def ref(self):
        """Returns a hashable reference object to this KerasTensor.

        The primary use case for this API is to put KerasTensors in a
        set/dictionary. We can't put tensors in a set/dictionary as
        `tensor.__hash__()` is not available and tensor equality (`==`) is
        supposed to produce a tensor representing if the two inputs are equal.

        See the documentation of `tf.Tensor.ref()` for more info.
        """
        return object_identity.Reference(self)

    @property
    def node(self):
        """Find the corresponding `Node` that produce this keras_tensor.

        During functional model construction, Keras will attach `KerasHistory`
        to keras tensor to track the connectivity between calls of layers.
        Return None if there isn't any KerasHistory attached to this tensor.
        """
        if hasattr(self, "_keras_history"):
            layer, node_index, _ = self._keras_history
            return layer.inbound_nodes[node_index]
        return None

    def __iter__(self):
        shape = None
        if self.shape.ndims is not None:
            shape = [dim.value for dim in self.shape.dims]

        if shape is None:
            raise TypeError("Cannot iterate over a Tensor with unknown shape.")
        if not shape:
            raise TypeError("Cannot iterate over a scalar.")
        if shape[0] is None:
            raise TypeError(
                "Cannot iterate over a Tensor with unknown first dimension."
            )
        return _KerasTensorIterator(self, shape[0])

    @property
    def name(self):
        """Returns the (non-unique, optional) name of this symbolic Keras
        value."""
        return self._name

    @classmethod
    def _overload_all_operators(cls, tensor_class):
        """Register overloads for all operators."""
        for operator in tf.Tensor.OVERLOADABLE_OPERATORS:
            cls._overload_operator(tensor_class, operator)

        # We include `experimental_ref` for versions of TensorFlow that
        # still include the deprecated method in Tensors.
        if hasattr(tensor_class, "experimental_ref"):
            cls._overload_operator(tensor_class, "experimental_ref")

    @classmethod
    def _overload_operator(cls, tensor_class, operator):
        """Overload an operator with the same implementation as a base Tensor class.

        We pull the operator out of the class dynamically to avoid ordering
        issues.

        Args:
          tensor_class: The (Composite)Tensor to get the method from.
          operator: string. The operator name.
        """
        tensor_oper = getattr(tensor_class, operator)

        # Compatibility with Python 2:
        # Python 2 unbound methods have type checks for the first arg,
        # so we need to extract the underlying function
        tensor_oper = getattr(tensor_oper, "__func__", tensor_oper)

        setattr(cls, operator, tensor_oper)


KerasTensor._overload_all_operators(tf.Tensor)


class SparseKerasTensor(KerasTensor):
    """A specialized KerasTensor representation for `tf.sparse.SparseTensor`s.

    Specifically, it specializes the conversion to a placeholder in order
    to maintain dense shape information.
    """

    def _to_placeholder(self):
        spec = self.type_spec

        # nest.map_structure loses dense shape information for sparse tensors.
        # So, we special-case sparse placeholder creation.
        # This only preserves shape information for top-level sparse tensors;
        # not for sparse tensors that are nested inside another composite
        # tensor.
        return tf.compat.v1.sparse_placeholder(
            dtype=spec.dtype, shape=spec.shape
        )


class RaggedKerasTensor(KerasTensor):
    """A specialized KerasTensor representation for `tf.RaggedTensor`s.

    Specifically, it:

    1. Specializes the conversion to a placeholder in order
    to maintain shape information for non-ragged dimensions.
    2. Overloads the KerasTensor's operators with the RaggedTensor versions
    when they don't match the `tf.Tensor` versions
    3. Exposes some of the instance method/attribute that are unique to
    the RaggedTensor API (such as ragged_rank).
    """

    def _to_placeholder(self):
        ragged_spec = self.type_spec
        if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None:
            return super()._to_placeholder()

        flat_shape = ragged_spec.shape[ragged_spec.ragged_rank :]
        result = tf.compat.v1.placeholder(ragged_spec.dtype, flat_shape)

        known_num_splits = []
        prod = 1
        for axis_size in ragged_spec.shape:
            if prod is not None:
                if axis_size is None or (
                    getattr(axis_size, "value", True) is None
                ):
                    prod = None
                else:
                    prod = prod * axis_size
            known_num_splits.append(prod)

        for axis in range(ragged_spec.ragged_rank, 0, -1):
            axis_size = ragged_spec.shape[axis]
            if axis_size is None or (getattr(axis_size, "value", True) is None):
                num_splits = known_num_splits[axis - 1]
                if num_splits is not None:
                    num_splits = num_splits + 1
                splits = tf.compat.v1.placeholder(
                    ragged_spec.row_splits_dtype, [num_splits]
                )
                result = tf.RaggedTensor.from_row_splits(
                    result, splits, validate=False
                )
            else:
                rowlen = tf.constant(axis_size, ragged_spec.row_splits_dtype)
                result = tf.RaggedTensor.from_uniform_row_length(
                    result, rowlen, validate=False
                )
        return result

    @property
    def ragged_rank(self):
        return self.type_spec.ragged_rank


# Overload slicing
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__getitem__")

# Overload math ops
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__add__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__radd__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__mul__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__rmul__")


# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
class UserRegisteredSpec(tf.TypeSpec):
    """TypeSpec to represent user-registered symbolic objects."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self._dtype = dtype
        self.dtype = dtype

    def _component_specs(self):
        raise NotImplementedError

    def _from_components(self, components):
        raise NotImplementedError

    def _serialize(self):
        raise NotImplementedError

    def _to_components(self, value):
        raise NotImplementedError

    def value_type(self):
        raise NotImplementedError


# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
class UserRegisteredTypeKerasTensor(KerasTensor):
    """KerasTensor that represents legacy register_symbolic_tensor_type."""

    def __init__(self, user_registered_symbolic_object):
        x = user_registered_symbolic_object
        self._user_registered_symbolic_object = x
        type_spec = UserRegisteredSpec(x.shape, x.dtype)
        name = getattr(x, "name", None)

        super().__init__(type_spec, name)

    @classmethod
    def from_tensor(cls, tensor):
        return cls(tensor)

    @classmethod
    def from_type_spec(cls, type_spec, name=None):
        raise NotImplementedError(
            "You cannot instantiate a KerasTensor "
            "directly from TypeSpec: %s" % type_spec
        )

    def _to_placeholder(self):
        return self._user_registered_symbolic_object


class _KerasTensorIterator:
    """Iterates over the leading dim of a KerasTensor. Performs 0 error
    checks."""

    def __init__(self, tensor, dim0):
        self._tensor = tensor
        self._index = 0
        self._limit = dim0

    def __iter__(self):
        return self

    def __next__(self):
        if self._index == self._limit:
            raise StopIteration
        result = self._tensor[self._index]
        self._index += 1
        return result


# Specify the mappings of tensor class to KerasTensor class.
# This is specifically a list instead of a dict for now because
# 1. we do a check w/ isinstance because a key lookup based on class
#    would miss subclasses
# 2. a list allows us to control lookup ordering
# We include ops.Tensor -> KerasTensor in the first position as a fastpath,
# *and* include object -> KerasTensor at the end as a catch-all.
# We can re-visit these choices in the future as needed.
keras_tensor_classes = [
    (tf.Tensor, KerasTensor),
    (tf.SparseTensor, SparseKerasTensor),
    (tf.RaggedTensor, RaggedKerasTensor),
    (object, KerasTensor),
]


def register_keras_tensor_specialization(cls, keras_tensor_subclass):
    """Register a specialized KerasTensor subclass for a Tensor type."""
    # We always leave (object, KerasTensor) at the end as a generic fallback
    keras_tensor_classes.insert(-1, (cls, keras_tensor_subclass))


def keras_tensor_to_placeholder(x):
    """Construct a graph placeholder to represent a KerasTensor when tracing."""
    if isinstance(x, KerasTensor):
        return x._to_placeholder()
    else:
        return x


def keras_tensor_from_tensor(tensor):
    """Convert a traced (composite)tensor to a representative KerasTensor."""
    # Create a specialized KerasTensor that supports instance methods,
    # operators, and additional value inference if possible
    keras_tensor_cls = None
    for tensor_type, cls in keras_tensor_classes:
        if isinstance(tensor, tensor_type):
            keras_tensor_cls = cls
            break

    out = keras_tensor_cls.from_tensor(tensor)

    if getattr(tensor, "_keras_mask", None) is not None:
        out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask)
    return out


def keras_tensor_from_type_spec(type_spec, name=None):
    """Convert a TypeSpec to a representative KerasTensor."""
    # Create a specialized KerasTensor that supports instance methods,
    # operators, and additional value inference if possible
    keras_tensor_cls = None
    value_type = type_spec.value_type
    for tensor_type, cls in keras_tensor_classes:
        if issubclass(value_type, tensor_type):
            keras_tensor_cls = cls
            break

    return keras_tensor_cls.from_type_spec(type_spec, name=name)


def type_spec_with_shape(spec, shape):
    """Returns a copy of TypeSpec `spec` with its shape set to `shape`."""
    if isinstance(spec, tf.TensorSpec):

        # TODO(b/203201161) Figure out why mutation is needed here, and remove
        # it. (TensorSpec objects should be immutable; and we should not be
        # modifying private fields.)
        shape = tf.TensorShape(shape)
        spec._shape = shape
        return spec
    elif isinstance(spec, tf.RaggedTensorSpec):
        return tf.RaggedTensorSpec(
            shape,
            spec.dtype,
            spec.ragged_rank,
            spec.row_splits_dtype,
            spec.flat_values_spec,
        )
    elif isinstance(spec, tf.SparseTensorSpec):
        return tf.SparseTensorSpec(shape, spec.dtype)
    elif hasattr(spec, "with_shape"):
        # TODO(edloper): Consider adding .with_shape method to TensorSpec,
        # RaggedTensorSpec, and SparseTensorSpec.
        return spec.with_shape(shape)
    else:
        # TODO(edloper): Consider moving this check to the KerasTensor
        # constructor.
        raise ValueError(
            "Keras requires TypeSpec to have a `with_shape` method "
            "that returns a copy of `self` with an updated shape."
        )
