# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras backend config API."""

import tensorflow.compat.v2 as tf

# isort: off
from tensorflow.python.util.tf_export import keras_export

# The type of float to use throughout a session.
_FLOATX = "float32"

# Epsilon fuzz factor used throughout the codebase.
_EPSILON = 1e-7

# Default image data format, one of "channels_last", "channels_first".
_IMAGE_DATA_FORMAT = "channels_last"


@keras_export("keras.backend.epsilon")
@tf.__internal__.dispatch.add_dispatch_support
def epsilon():
    """Returns the value of the fuzz factor used in numeric expressions.

    Returns:
        A float.

    Example:
    >>> tf.keras.backend.epsilon()
    1e-07
    """
    return _EPSILON


@keras_export("keras.backend.set_epsilon")
def set_epsilon(value):
    """Sets the value of the fuzz factor used in numeric expressions.

    Args:
        value: float. New value of epsilon.

    Example:
    >>> tf.keras.backend.epsilon()
    1e-07
    >>> tf.keras.backend.set_epsilon(1e-5)
    >>> tf.keras.backend.epsilon()
    1e-05
     >>> tf.keras.backend.set_epsilon(1e-7)
    """
    global _EPSILON
    _EPSILON = value


@keras_export("keras.backend.floatx")
def floatx():
    """Returns the default float type, as a string.

    E.g. `'float16'`, `'float32'`, `'float64'`.

    Returns:
        String, the current default float type.

    Example:
    >>> tf.keras.backend.floatx()
    'float32'
    """
    return _FLOATX


@keras_export("keras.backend.set_floatx")
def set_floatx(value):
    """Sets the default float type.

    Note: It is not recommended to set this to float16 for training, as this
    will likely cause numeric stability issues. Instead, mixed precision, which
    is using a mix of float16 and float32, can be used by calling
    `tf.keras.mixed_precision.set_global_policy('mixed_float16')`. See the
    [mixed precision guide](
      https://www.tensorflow.org/guide/keras/mixed_precision) for details.

    Args:
        value: String; `'float16'`, `'float32'`, or `'float64'`.

    Example:
    >>> tf.keras.backend.floatx()
    'float32'
    >>> tf.keras.backend.set_floatx('float64')
    >>> tf.keras.backend.floatx()
    'float64'
    >>> tf.keras.backend.set_floatx('float32')

    Raises:
        ValueError: In case of invalid value.
    """
    global _FLOATX
    accepted_dtypes = {"float16", "float32", "float64"}
    if value not in accepted_dtypes:
        raise ValueError(
            f"Unknown `floatx` value: {value}. "
            f"Expected one of {accepted_dtypes}"
        )
    _FLOATX = str(value)


@keras_export("keras.backend.image_data_format")
@tf.__internal__.dispatch.add_dispatch_support
def image_data_format():
    """Returns the default image data format convention.

    Returns:
        A string, either `'channels_first'` or `'channels_last'`

    Example:
    >>> tf.keras.backend.image_data_format()
    'channels_last'
    """
    return _IMAGE_DATA_FORMAT


@keras_export("keras.backend.set_image_data_format")
def set_image_data_format(data_format):
    """Sets the value of the image data format convention.

    Args:
        data_format: string. `'channels_first'` or `'channels_last'`.

    Example:
    >>> tf.keras.backend.image_data_format()
    'channels_last'
    >>> tf.keras.backend.set_image_data_format('channels_first')
    >>> tf.keras.backend.image_data_format()
    'channels_first'
    >>> tf.keras.backend.set_image_data_format('channels_last')

    Raises:
        ValueError: In case of invalid `data_format` value.
    """
    global _IMAGE_DATA_FORMAT
    accepted_formats = {"channels_last", "channels_first"}
    if data_format not in accepted_formats:
        raise ValueError(
            f"Unknown `data_format`: {data_format}. "
            f"Expected one of {accepted_formats}"
        )
    _IMAGE_DATA_FORMAT = str(data_format)
