# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializer serialization / deserialization."""

import threading

from tensorflow.python import tf2
from tensorflow.python.keras.initializers import initializers_v1
from tensorflow.python.keras.initializers import initializers_v2
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_inspect as inspect
from tensorflow.python.ops import init_ops
from tensorflow.python.util.tf_export import keras_export


# LOCAL is a thread-local namespace. `populate_deserializable_objects()` stores
# two attributes on it: `ALL_OBJECTS`, a mutable dict mapping names to built-in
# initializers, and `GENERATED_WITH_V2`, the `tf2.enabled()` value that dict
# was built for. This state is mutable and rebuilt on demand, so it is kept
# thread-local to avoid concurrent mutations of one shared dict.
LOCAL = threading.local()


def populate_deserializable_objects():
  """Builds the per-thread name -> initializer lookup table.

  The table is stored on `LOCAL.ALL_OBJECTS`, and `LOCAL.GENERATED_WITH_V2`
  records the `tf2.enabled()` value it was built for. When a non-empty table
  already exists for the current `tf2.enabled()` value this is a no-op;
  otherwise a fresh table is created and filled.
  """
  use_v2 = tf2.enabled()
  cached = getattr(LOCAL, 'ALL_OBJECTS', None)
  if cached and LOCAL.GENERATED_WITH_V2 == use_v2:
    # A table for the active TF version is already in place.
    return

  # Publish the new (still empty) table first, then fill it in place.
  registry = {}
  LOCAL.ALL_OBJECTS = registry
  LOCAL.GENERATED_WITH_V2 = use_v2

  # Compatibility aliases (registered regardless of the V1/V2 mode).
  registry.update({
      'ConstantV2': initializers_v2.Constant,
      'GlorotNormalV2': initializers_v2.GlorotNormal,
      'GlorotUniformV2': initializers_v2.GlorotUniform,
      'HeNormalV2': initializers_v2.HeNormal,
      'HeUniformV2': initializers_v2.HeUniform,
      'IdentityV2': initializers_v2.Identity,
      'LecunNormalV2': initializers_v2.LecunNormal,
      'LecunUniformV2': initializers_v2.LecunUniform,
      'OnesV2': initializers_v2.Ones,
      'OrthogonalV2': initializers_v2.Orthogonal,
      'RandomNormalV2': initializers_v2.RandomNormal,
      'RandomUniformV2': initializers_v2.RandomUniform,
      'TruncatedNormalV2': initializers_v2.TruncatedNormal,
      'VarianceScalingV2': initializers_v2.VarianceScaling,
      'ZerosV2': initializers_v2.Zeros,
  })

  # Extra aliases kept out of caution: they may have been written into saved
  # configs in the past.
  registry.update({
      'glorot_normalV2': initializers_v2.GlorotNormal,
      'glorot_uniformV2': initializers_v2.GlorotUniform,
      'he_normalV2': initializers_v2.HeNormal,
      'he_uniformV2': initializers_v2.HeUniform,
      'lecun_normalV2': initializers_v2.LecunNormal,
      'lecun_uniformV2': initializers_v2.LecunUniform,
  })

  if use_v2:
    # V2 entries are discovered from the contents of initializers_v2.py.
    root_cls = initializers_v2.Initializer

    def _is_initializer_class(obj):
      return inspect.isclass(obj) and issubclass(obj, root_cls)

    builtins_by_name = {}
    generic_utils.populate_dict_with_module_objects(
        builtins_by_name,
        [initializers_v2],
        obj_filter=_is_initializer_class)
  else:
    # V1 entries are listed explicitly.
    builtins_by_name = {
        'Constant': init_ops.Constant,
        'GlorotNormal': init_ops.GlorotNormal,
        'GlorotUniform': init_ops.GlorotUniform,
        'Identity': init_ops.Identity,
        'Ones': init_ops.Ones,
        'Orthogonal': init_ops.Orthogonal,
        'VarianceScaling': init_ops.VarianceScaling,
        'Zeros': init_ops.Zeros,
        'HeNormal': initializers_v1.HeNormal,
        'HeUniform': initializers_v1.HeUniform,
        'LecunNormal': initializers_v1.LecunNormal,
        'LecunUniform': initializers_v1.LecunUniform,
        'RandomNormal': initializers_v1.RandomNormal,
        'RandomUniform': initializers_v1.RandomUniform,
        'TruncatedNormal': initializers_v1.TruncatedNormal,
    }

  # Register each entry under its own name and under the snake_case form of
  # that name (the functional alias).
  for name, initializer_cls in builtins_by_name.items():
    registry[name] = initializer_cls
    registry[generic_utils.to_snake_case(name)] = initializer_cls

  # More compatibility aliases, each pointing at an entry registered above.
  for alias, target in (('normal', 'random_normal'),
                        ('uniform', 'random_uniform'),
                        ('one', 'ones'),
                        ('zero', 'zeros')):
    registry[alias] = registry[target]


# For backwards compatibility, we populate this file with the objects
# from ALL_OBJECTS. We make no guarantees as to whether these objects will
# be using their correct version: the names are copied into the module
# namespace once, here at import time, and `populate_deserializable_objects()`
# does not update them when it later rebuilds `LOCAL.ALL_OBJECTS`.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)

# Utility functions


@keras_export('keras.initializers.serialize')
def serialize(initializer):
  """Serializes `initializer` via `generic_utils.serialize_keras_object`.

  Args:
    initializer: The object to serialize; it is passed through unchanged.

  Returns:
    The result of `generic_utils.serialize_keras_object(initializer)`.
  """
  serialized = generic_utils.serialize_keras_object(initializer)
  return serialized


@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
  """Builds an initializer from `config`.

  Args:
    config: Identifier forwarded to `generic_utils.deserialize_keras_object`.
      `get()` in this module passes either a string name or a dict here.
    custom_objects: Optional value forwarded unchanged as the `custom_objects`
      argument of `generic_utils.deserialize_keras_object`.

  Returns:
    Whatever `generic_utils.deserialize_keras_object` returns for `config`.
  """
  # Make sure this thread's registry exists and matches the current
  # `tf2.enabled()` setting before it is handed to the lookup.
  populate_deserializable_objects()
  lookup_kwargs = dict(
      module_objects=LOCAL.ALL_OBJECTS,
      custom_objects=custom_objects,
      printable_module_name='initializer')
  return generic_utils.deserialize_keras_object(config, **lookup_kwargs)


@keras_export('keras.initializers.get')
def get(identifier):
  """Retrieve a Keras initializer by the identifier.

  The `identifier` may be the string name of an initializer function or class
  (case-sensitive).

  >>> identifier = 'Ones'
  >>> tf.keras.initializers.get(identifier)
  <...keras.initializers.initializers_v2.Ones...>

  You can also specify the `config` of the initializer by passing a dict
  containing `class_name` and `config` as the identifier. Also note that the
  `class_name` must map to an `Initializer` class.

  >>> cfg = {'class_name': 'Ones', 'config': {}}
  >>> tf.keras.initializers.get(cfg)
  <...keras.initializers.initializers_v2.Ones...>

  If the `identifier` is a class, a new instance is created by calling it with
  no arguments. Any other callable is returned unchanged.

  Args:
    identifier: `None`, a string name, a dict holding the initializer
      configuration, a class, or any other callable.

  Returns:
    `None` if `identifier` is `None`. For a string or dict, the result of
    `deserialize(identifier)`. For a class, a new instance of it. For any
    other callable, the callable itself.

  Raises:
    ValueError: If the input identifier is not a supported type or in a bad
      format.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, str):
    # Coerce to the built-in `str` type before the lookup (a no-op for plain
    # strings).
    return deserialize(str(identifier))
  if callable(identifier):
    # Classes are instantiated with no arguments; other callables are assumed
    # to be ready to use and are passed through as-is.
    return identifier() if inspect.isclass(identifier) else identifier
  raise ValueError('Could not interpret initializer identifier: ' +
                   str(identifier))
