# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import warnings

import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine import base_layer_utils
from keras.engine import base_layer_v1 as base_layer
from keras.legacy_tf_layers import variable_scope_shim
from keras.mixed_precision import policy
from keras.utils import tf_contextlib

# isort: off
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export

# Module-level flag: True while Keras-style variable management is enabled.
# It is written by `keras_style_scope` and `set_keras_style`, and read through
# `_is_in_keras_style_scope`.
_KERAS_STYLE_SCOPE = False


@keras_export(
    v1=["keras.__internal__.legacy.layers.experimental.keras_style_scope"]
)
@tf_export(v1=["layers.experimental.keras_style_scope"])
@tf_contextlib.contextmanager
def keras_style_scope():
    """Context manager that enables Keras-style variable management.

    Every tf.layers layer and tf RNN cell created inside this scope uses
    Keras-style variable management.  Inside the scope, constructing such a
    layer with a `scope=` argument is disallowed, and so is `reuse=True`.

    This scope exists so that users of existing layers can migrate gradually
    to a Keras layers API without breaking existing functionality.

    A typical case is using TensorFlow's RNN classes with Keras Models or
    Networks.  Because Keras models do not properly set variable scopes,
    users of RNNs may either accidentally share scopes between two different
    models, or get errors about variables that already exist.

    Example:

    ```python
    class RNNModel(tf.keras.Model):

      def __init__(self, name):
        super(RNNModel, self).__init__(name=name)
        self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
          [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

      def call(self, input, state):
        return self.rnn(input, state)

    model_1 = RNNModel("model_1")
    model_2 = RNNModel("model_2")

    # OK
    output_1, next_state_1 = model_1(input, state)
    # Raises an error about trying to create an already existing variable.
    output_2, next_state_2 = model_2(input, state)
    ```

    Wrapping both model construction and execution in a keras-style scope
    resolves this:

    ```python
    with keras_style_scope():
      model_1 = RNNModel("model_1")
      model_2 = RNNModel("model_2")

      # model_1 and model_2 are guaranteed to create their own variables.
      output_1, next_state_1 = model_1(input, state)
      output_2, next_state_2 = model_2(input, state)

      assert len(model_1.weights) > 0
      assert len(model_2.weights) > 0
      assert(model_1.weights != model_2.weights)
    ```

    Yields:
      Nothing. Keras-style variable management is active for the body of
      the `with` statement, and the previous setting is restored on exit.
    """
    global _KERAS_STYLE_SCOPE
    # Remember the enclosing setting so that nested scopes unwind correctly.
    previous_setting = _KERAS_STYLE_SCOPE
    _KERAS_STYLE_SCOPE = True
    try:
        yield
    finally:
        # Restore even when the body raises.
        _KERAS_STYLE_SCOPE = previous_setting


@keras_export(
    v1=["keras.__internal__.legacy.layers.experimental.set_keras_style"]
)
@tf_export(v1=["layers.experimental.set_keras_style"])
def set_keras_style():
    """Use Keras-style variable management.

    All tf.layers and tf RNN cells created after keras style has been enabled
    use Keras-style variable management.  Creating such layers with a
    scope= argument is disallowed, and reuse=True is disallowed.

    The purpose of this function is to allow users of existing layers to
    slowly transition to a Keras layers API without breaking existing
    functionality.

    For more details, see the documentation for `keras_style_scope`.

    Note, once keras style has been set, it is set globally for the entire
    program and cannot be unset.

    Example:

    ```python
    set_keras_style()

    model_1 = RNNModel(name="model_1")
    model_2 = RNNModel(name="model_2")

    # model_1 and model_2 are guaranteed to create their own variables.
    output_1, next_state_1 = model_1(input, state)
    output_2, next_state_2 = model_2(input, state)

    assert len(model_1.weights) > 0
    assert len(model_2.weights) > 0
    assert(model_1.weights != model_2.weights)
    ```
    """
    # Flip the module-level flag; nothing in this function restores it.
    global _KERAS_STYLE_SCOPE
    _KERAS_STYLE_SCOPE = True


def _is_in_keras_style_scope():
    """Returns the current value of the module-level Keras-style flag."""
    # Reading a module global needs no `global` declaration.
    return _KERAS_STYLE_SCOPE


@keras_export(v1=["keras.__internal__.legacy.layers.Layer"])
@tf_export(v1=["layers.Layer"])
class Layer(base_layer.Layer):
    """Base layer class.

    It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
    instead.

    Args:
      trainable: Boolean, whether the layer's variables should be trainable.
      name: String name of the layer.
      dtype: Default dtype of the layer's weights (default of `None` means use
        the type of the first input).

    Read-only properties:
      name: The name of the layer (string).
      dtype: Default dtype of the layer's weights (default of `None` means use
        the type of the first input).
      trainable_variables: List of trainable variables.
      non_trainable_variables: List of non-trainable variables.
      variables: List of all variables of this layer, trainable and
        non-trainable.
      updates: List of update ops of this layer.
      losses: List of losses added by this layer.
      trainable_weights: List of variables to be included in backprop.
      non_trainable_weights: List of variables that should not be
        included in backprop.
      weights: The concatenation of the lists trainable_weights and
        non_trainable_weights (in this order).

    Mutable properties:
      trainable: Whether the layer should be trained (boolean).
      input_spec: Optional (list of) `InputSpec` object(s) specifying the
        constraints on inputs that can be accepted by the layer.
    """

    def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
        """Initializes the legacy layer.

        Args:
          trainable: Boolean, whether the layer's variables should be
            trainable.
          name: String name of the layer.
          dtype: Default dtype of the layer's weights. `None` selects the
            "_infer" policy so that the dtype is inferred from the inputs.
          **kwargs: Forwarded to the parent constructor. The private keys
            `_scope` and `_reuse` are consumed here, and `autocast` defaults
            to `False` when not supplied.

        Raises:
          ValueError: If `_scope` or `_reuse` is given (not `None`) while
            Keras-style variable management is enabled.
        """
        # For backwards compatibility, legacy layers do not use
        # `ResourceVariable` by default.
        self._use_resource_variables = False
        scope = kwargs.pop("_scope", None)
        self._reuse = kwargs.pop("_reuse", None)

        # Avoid an incorrect lint error
        self._trainable_weights = []
        self.built = False

        if dtype is None:
            # Indicates to infer dtype from inputs. When the V2 dtype behavior
            # is enabled, Keras layers default their dtype to floatx instead, so
            # we pass an "_infer" policy to keep the old V1 behavior.
            dtype = policy.Policy("_infer")

        if "autocast" not in kwargs:
            kwargs["autocast"] = False

        # Mark that legacy layers should not be instrumented as Keras usage
        self._disable_keras_instrumentation = True

        super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)

        if _is_in_keras_style_scope():
            if scope is not None:
                raise ValueError(
                    "scope argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(scope)
                )
            if self._reuse is not None:
                raise ValueError(
                    "reuse argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(self._reuse)
                )
            self._keras_style = True
        else:
            self._keras_style = False

        self._call_has_scope_arg = "scope" in self._call_spec.arg_names
        if scope:
            # Enter the scope once only to capture the resulting scope object.
            with tf.compat.v1.variable_scope(scope) as captured_scope:
                self._scope = captured_scope
        else:
            # Resolved lazily by `_set_scope`.
            self._scope = None
        self._current_scope = None

    def apply(self, *args, **kwargs):
        """Calls the layer; equivalent to `self(*args, **kwargs)`."""
        return self(*args, **kwargs)

    # We no longer track graph in tf.layers layers. This property is only kept
    # to maintain API backward compatibility.
    @property
    def graph(self):
        """Deprecated property kept for API backward compatibility.

        Emits a warning on every access.

        Returns:
          Always `None` when not executing eagerly.

        Raises:
          RuntimeError: When executing eagerly.
        """
        warnings.warn(
            "`Layer.graph` is deprecated and "
            "will be removed in a future version. "
            "Please stop using this property because tf.layers layers no "
            "longer track their graph.",
            stacklevel=2,
        )
        if tf.executing_eagerly():
            raise RuntimeError(
                "Layer.graph not supported when executing eagerly."
            )
        return None

    def _init_set_name(self, name):
        """Sets `self._name` and `self._base_name` from `name`.

        Args:
          name: A string, a `tf.compat.v1.VariableScope`, or a falsy value.
            For a `VariableScope`, the base name is the scope's name and the
            layer name is a generated unique name. For a falsy value, both
            names come from `_make_unique_name`.
        """
        # Determine layer name (non-unique).
        if isinstance(name, tf.compat.v1.VariableScope):
            base_name = name.name
            self._name, _ = self._make_unique_name()
        else:
            base_name = name
            self._name = name
        if not name:
            self._name, base_name = self._make_unique_name()
        self._base_name = base_name

    def _make_unique_name(
        self,
        name_uid_map=None,
        avoid_names=None,
        namespace="",
        zero_based=False,
    ):
        """Builds a unique object name from the snake-cased class name.

        All arguments are forwarded unchanged to
        `backend.unique_object_name`.

        Returns:
          A `(name, base_name)` tuple, where `base_name` is the snake-cased
          class name and `name` is the unique name derived from it.
        """
        base_name = base_layer.to_snake_case(self.__class__.__name__)
        name = backend.unique_object_name(
            base_name,
            name_uid_map=name_uid_map,
            avoid_names=avoid_names,
            namespace=namespace,
            zero_based=zero_based,
        )
        return (name, base_name)

    @property
    def scope_name(self):
        """Name of the variable scope captured for this layer.

        Raises:
          ValueError: If no scope has been captured yet (`self._scope` is
            unset).
        """
        if not self._scope:
            raise ValueError(
                'No name available for layer scope because the layer "'
                + self._name
                + '" has not been used yet. The scope name '
                + " is determined the first time the layer instance is "
                + "called. You must therefore call the layer before "
                + "querying `scope_name`."
            )
        return self._scope.name

    def add_loss(self, losses, inputs=None):
        """Adds losses via the parent class and mirrors them into a collection.

        When not executing eagerly, the losses newly registered by this call
        (including the non-`None` results of newly registered callable
        losses) are also added to the
        `tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES` collection.

        Args:
          losses: Forwarded to the parent `add_loss`.
          inputs: Forwarded to the parent `add_loss`.
        """
        # Snapshot the lengths so that only the entries added by this call are
        # pushed to the collection.
        previous_losses_length = len(self._losses)
        previous_callable_losses_length = len(self._callable_losses)
        super().add_loss(losses, inputs=inputs)
        if not tf.executing_eagerly():
            # TODO(fchollet): deprecate collection below.
            new_losses = self._losses[previous_losses_length:]
            new_callable_losses = self._callable_losses[
                previous_callable_losses_length:
            ]
            for regularizer in new_callable_losses:
                loss_tensor = regularizer()
                if loss_tensor is not None:
                    new_losses.append(loss_tensor)
            _add_elements_to_collection(
                new_losses, tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES
            )

    def _name_scope(self):
        """Determines op naming for the Layer."""
        if self._keras_style:
            return super()._name_scope()
        # `_current_scope` is assigned by `add_weight` and `__call__` while
        # they are inside the layer's variable scope.
        return self._current_scope.original_name_scope

    def _set_scope(self, scope=None):
        """Captures the layer's variable scope if none has been captured yet.

        Does nothing when `self._scope` is already set.

        Args:
          scope: Optional scope to use. When `None`, `self._base_name` is
            used, either as the scope name (when `self._reuse` is truthy) or
            as the `default_name` (otherwise).
        """
        if self._scope is None:
            # If constructed with _scope=None, lazy setting of scope.
            if self._reuse:
                with tf.compat.v1.variable_scope(
                    scope if scope is not None else self._base_name
                ) as captured_scope:
                    self._scope = captured_scope
            else:
                with tf.compat.v1.variable_scope(
                    scope, default_name=self._base_name
                ) as captured_scope:
                    self._scope = captured_scope

    def add_weight(
        self,
        name,
        shape,
        dtype=None,
        initializer=None,
        regularizer=None,
        trainable=None,
        constraint=None,
        use_resource=None,
        synchronization=tf.VariableSynchronization.AUTO,
        aggregation=tf.compat.v1.VariableAggregation.NONE,
        partitioner=None,
        **kwargs
    ):
        """Adds a new variable to the layer, or gets an existing one.

        Args:
          name: variable name.
          shape: variable shape.
          dtype: The type of the variable. Defaults to `self.dtype` or
            `float32`.
          initializer: initializer instance (callable).
          regularizer: regularizer instance (callable).
          trainable: whether the variable should be part of the layer's
            "trainable_variables" (e.g. variables, biases)
            or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
            Note, if the current variable scope is marked as non-trainable
            then this parameter is ignored and any added variables are also
            marked as non-trainable. `trainable` defaults to `True` unless
            `synchronization` is set to `ON_READ`.
          constraint: constraint instance (callable).
          use_resource: Whether to use `ResourceVariable`.
          synchronization: Indicates when a distributed variable will be
            aggregated. Accepted values are constants defined in the class
            `tf.VariableSynchronization`. By default the synchronization is set
            to `AUTO` and the current `DistributionStrategy` chooses when to
            synchronize. If `synchronization` is set to `ON_READ`, `trainable`
            must not be set to `True`.
          aggregation: Indicates how a distributed variable will be aggregated.
            Accepted values are constants defined in the class
            `tf.VariableAggregation`.
          partitioner: (optional) partitioner instance (callable).  If
            provided, when the requested variable is created it will be split
            into multiple partitions according to `partitioner`.  In this case,
            an instance of `PartitionedVariable` is returned.  Available
            partitioners include `tf.compat.v1.fixed_size_partitioner` and
            `tf.compat.v1.variable_axis_size_partitioner`.  For more details,
            see the documentation of `tf.compat.v1.get_variable` and the
            "Variable Partitioners and Sharding" section of the API guide.
          **kwargs: Additional keyword arguments. Only
            `experimental_autocast` is accepted.

        Returns:
          The created variable.  Usually either a `Variable` or
          `ResourceVariable` instance.  If `partitioner` is not `None`, a
          `PartitionedVariable` instance is returned.

        Raises:
          RuntimeError: If called with partitioned variable regularization and
            eager execution is enabled.
          TypeError: If an unsupported keyword argument is passed.
          ValueError: When trainable has been set to True with synchronization
            set as `ON_READ`.
        """
        for kwarg in kwargs:
            if kwarg != "experimental_autocast":
                # Single formatted message; passing two arguments to the
                # exception would render the message as a tuple.
                raise TypeError(f"Unknown keyword argument: {kwarg}")
        if self._keras_style:
            # Forward the caller's `synchronization` / `aggregation` instead of
            # silently replacing them with the defaults.
            return super().add_weight(
                name=name,
                shape=shape,
                dtype=dtype,
                initializer=initializer,
                regularizer=regularizer,
                trainable=trainable and self.trainable,
                constraint=constraint,
                use_resource=use_resource,
                synchronization=synchronization,
                aggregation=aggregation,
                partitioner=partitioner,
                **kwargs
            )

        if synchronization == tf.VariableSynchronization.ON_READ:
            if trainable:
                raise ValueError(
                    "Synchronization value can be set to "
                    "VariableSynchronization.ON_READ only for non-trainable "
                    "variables. You have specified trainable=True and "
                    "synchronization=VariableSynchronization.ON_READ."
                )
            else:
                # Set trainable to be false when variable is to be synced on
                # read.
                trainable = False
        elif trainable is None:
            trainable = True

        def _should_add_regularizer(variable, existing_variable_set):
            """Returns True if no part of `variable` already existed."""
            if base_layer_utils.is_split_variable(variable):
                return not any(
                    var in existing_variable_set for var in variable
                )
            return variable not in existing_variable_set

        init_graph = None
        # Bound up front so that the regularizer check below never reads an
        # unassigned local on the branches that do not collect variables.
        existing_variables = set()
        if not tf.executing_eagerly():
            default_graph = tf.compat.v1.get_default_graph()
            if default_graph.building_function:
                with tf.init_scope():
                    # Retrieve the variables from the graph into which variables
                    # will be lifted; if initialization ops will be lifted into
                    # the eager context, then there is nothing to retrieve,
                    # since variable collections are not supported when eager
                    # execution is enabled.
                    if not tf.executing_eagerly():
                        init_graph = tf.compat.v1.get_default_graph()
                        existing_variables = set(
                            tf.compat.v1.global_variables()
                        )
            else:
                # Initialization ops will not be lifted out of the default
                # graph.
                init_graph = default_graph
                existing_variables = set(tf.compat.v1.global_variables())

        if dtype is None:
            dtype = self.dtype or tf.float32

        self._set_scope(None)
        reuse = self.built or self._reuse
        # Remember how many trainable weights existed so that the ones added by
        # this call can be identified afterwards.
        prev_len_trainable = len(self._trainable_weights)
        with tf.compat.v1.variable_scope(
            self._scope, reuse=reuse, auxiliary_name_scope=False
        ) as scope:
            self._current_scope = scope
            with backend.name_scope(self._name_scope()):
                use_resource = (
                    use_resource
                    or self._use_resource_variables
                    or scope.use_resource
                )
                if initializer is None:
                    initializer = scope.initializer
                variable = super().add_weight(
                    name,
                    shape,
                    dtype=tf.as_dtype(dtype),
                    initializer=initializer,
                    trainable=trainable and self.trainable,
                    constraint=constraint,
                    partitioner=partitioner,
                    use_resource=use_resource,
                    synchronization=synchronization,
                    aggregation=aggregation,
                    getter=tf.compat.v1.get_variable,
                    **kwargs
                )

                if regularizer:
                    if (
                        tf.compat.v1.executing_eagerly_outside_functions()
                        or _should_add_regularizer(variable, existing_variables)
                    ):
                        self._handle_weight_regularization(
                            name, variable, regularizer
                        )
                        var_store = vs._get_default_variable_store()
                        # When the shim to get variable scope working in TF2 is
                        # used, We need to explicitly make the shim track the
                        # regularization losses as the collections will not be
                        # accessible.
                        if hasattr(var_store, "add_regularizer"):
                            var_store.add_regularizer(variable, regularizer)

                if init_graph is not None:
                    # Handle edge case where a custom getter has overridden
                    # `trainable`.  There is one known occurrence of this, in
                    # unit test testBasicRNNCellNotTrainable in
                    # contrib.rnn.python.kernel_tests.core_rnn_cell_test
                    with init_graph.as_default():
                        trainable_variables = tf.compat.v1.trainable_variables()
                    if (
                        trainable
                        and self.trainable
                        and variable not in trainable_variables
                    ):
                        # A custom getter / variable scope overrode the
                        # trainable flag.
                        extra_trainable_vars = self._trainable_weights[
                            prev_len_trainable:
                        ]
                        self._trainable_weights = self._trainable_weights[
                            :prev_len_trainable
                        ]
                        self._non_trainable_weights += extra_trainable_vars
        return variable

    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

        Args:
          inputs: input tensor(s).
          *args: additional positional arguments to be passed to `self.call`.
          **kwargs: additional keyword arguments to be passed to `self.call`.
            **Note**: kwarg `scope` is reserved for use by the layer.

        Returns:
          Output tensor(s).

        Note:
          - If the layer's `call` method takes a `scope` keyword argument, this
            argument will be automatically set to the current variable scope.
          - If the layer's `call` method takes a `mask` argument (as some Keras
            layers do), its default value will be set to the mask generated
            for `inputs` by the previous layer (if `input` did come from
            a layer that generated a corresponding mask, i.e. if it came from
            a Keras layer with masking support).

        Raises:
          ValueError: if the layer's `call` method returns None (an invalid
            value), or if `scope` is passed while Keras-style variable
            management is enabled.
        """
        scope = kwargs.pop("scope", None)

        if self._keras_style:
            if scope is not None:
                raise ValueError(
                    "scope argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(scope)
                )
            return super().__call__(inputs, *args, **kwargs)

        self._set_scope(scope)

        if self.built:
            try:
                # Some classes which inherit from Layer do not use its
                # constructor, so rather than initializing to None we check for
                # an AttributeError.
                scope_context_manager = self._always_reuse_variable_scope
            except AttributeError:
                scope_context_manager = None

            if scope_context_manager is None:
                # From this point we will always set reuse=True, so create a
                # "final" variable scope with this setting. We avoid re-creating
                # variable scopes after this point as an optimization.
                scope_context_manager = tf.compat.v1.variable_scope(
                    self._scope, reuse=True, auxiliary_name_scope=False
                )

                # Do not cache variable scopes if Eager mode is enabled. If
                # Eager mode is enabled then we don't want to reuse scopes
                # because the cached scope might be from a FuncGraph or Eager
                # scope we are no longer in.
                if not tf.compat.v1.executing_eagerly_outside_functions():
                    self._always_reuse_variable_scope = scope_context_manager
        else:
            scope_context_manager = tf.compat.v1.variable_scope(
                self._scope, reuse=self._reuse, auxiliary_name_scope=False
            )

        with scope_context_manager as scope:
            self._current_scope = scope

            try:
                call_has_scope_arg = self._call_has_scope_arg
            except AttributeError:
                # Subclasses that skipped this class's constructor never set
                # the attribute; compute it from `call`'s signature on demand.
                self._call_spec.arg_names = variable_scope_shim.fn_args(
                    self.call
                )
                self._call_has_scope_arg = "scope" in self._call_spec.arg_names
                call_has_scope_arg = self._call_has_scope_arg
            if call_has_scope_arg:
                kwargs["scope"] = scope

            # Actually call layer
            outputs = super().__call__(inputs, *args, **kwargs)

        if not tf.executing_eagerly():
            # Update global default collections.
            _add_elements_to_collection(
                self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
            )
        return outputs

    def __deepcopy__(self, memo):
        """Deep-copies the layer, with special handling for some attributes.

        Attributes named in `no_copy`, and tensor or tensor-list values, are
        shared with the copy by reference. Attributes named in
        `shallow_copy` are copied with `copy.copy`. Everything else is
        deep-copied.
        """
        no_copy = {"_graph", "_thread_local", "_metrics_lock"}
        shallow_copy = {"_scope", "_always_reuse_variable_scope"}
        cls = self.__class__
        result = cls.__new__(cls)
        # Register the copy before recursing so that reference cycles back to
        # this layer resolve to `result`.
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in no_copy:
                setattr(result, k, v)
            elif k in shallow_copy:
                setattr(result, k, copy.copy(v))
            elif base_layer.is_tensor_or_tensor_list(v):
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

    def __setattr__(self, value, name):
        # By-pass the automatic dependency tracking performed by the parent
        # Layer.
        # NOTE: Python invokes `__setattr__(attr_name, attr_value)`
        # positionally, so despite the parameter names here, `value` receives
        # the attribute name and `name` receives the attribute value. They are
        # forwarded in the same order, so the assignment itself is correct.
        super(tf.__internal__.tracking.Trackable, self).__setattr__(value, name)

    @property
    def _is_legacy_layer(self):
        """Used by keras to check compatibility. This should not be
        overridden."""
        return True


def _add_elements_to_collection(elements, collection_list):
    """Appends `elements` to each named graph collection, skipping duplicates.

    Membership is decided by object identity (`id`). An element is skipped
    if the same object is already in the collection, including when it was
    added earlier in this same call.

    Args:
      elements: A possibly nested structure of objects to add; it is
        flattened with `tf.nest.flatten`.
      collection_list: A collection name or a possibly nested structure of
        collection names; it is flattened with `tf.nest.flatten`.

    Raises:
      RuntimeError: When executing eagerly.
    """
    if tf.executing_eagerly():
        raise RuntimeError(
            "Using collections from Layers not supported in Eager "
            "mode. Tried to add %s to %s" % (elements, collection_list)
        )
    elements = tf.nest.flatten(elements)
    collection_list = tf.nest.flatten(collection_list)
    for name in collection_list:
        collection = tf.compat.v1.get_collection_ref(name)
        seen_ids = {id(e) for e in collection}
        for element in elements:
            element_id = id(element)
            if element_id not in seen_ids:
                collection.append(element)
                # Record the id so that a repeated object within `elements` is
                # not appended a second time.
                seen_ids.add(element_id)
