# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for internal use."""
# pylint: disable=g-direct-tensorflow-import

import inspect
import numbers
import os
import re
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.types import core
from tensorflow.python.util import nest


def _canonicalize_axis(axis, rank):
  return _canonicalize_axes([axis], rank)[0]


def _canonicalize_axes(axes, rank):
  rank = _maybe_static(rank)

  if isinstance(rank, core.Tensor):
    canonicalizer = (
        lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
  else:
    canonicalizer = lambda axis: axis + rank if axis < 0 else axis

  return [canonicalizer(axis) for axis in axes]

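# Illustrative sketch of axis canonicalization (commented out, not executed;
# assumes plain Python ints, so the static branch is taken):
#
#   _canonicalize_axes([0, -1, -2], rank=3)  # -> [0, 2, 1]
#   _canonicalize_axis(-1, rank=3)           # -> 2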

def _supports_signature():
  return hasattr(inspect, 'signature')


def _to_tf_type(dtype):
  """Converts a native python or numpy type to TF DType.

  Args:
    dtype: Could be a python type, a numpy type or a TF DType.

  Returns:
    A tensorflow `DType`.
  """
  return dtypes.as_dtype(dtype)


def _to_numpy_type(dtype):
  """Converts a native python or TF DType to numpy type.

  Args:
    dtype: Could be a python type, a numpy type or a TF DType.

  Returns:
    A NumPy `dtype`.
  """
  if isinstance(dtype, dtypes.DType):
    return dtype.as_numpy_dtype
  return np.dtype(dtype)

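# Illustrative conversions (commented out; note that `_to_numpy_type` returns
# an np.dtype for non-DType inputs and a scalar type for DType inputs):
#
#   _to_tf_type(np.float32)       # -> tf.float32 (a dtypes.DType)
#   _to_numpy_type(dtypes.int32)  # -> np.int32 (via DType.as_numpy_dtype)
#   _to_numpy_type('float32')     # -> np.dtype('float32')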

def isscalar(val):
  """Returns whether `val` is a scalar value or scalar Tensor."""
  if isinstance(val, np_arrays.ndarray):
    val = val.data
  if isinstance(val, core.Tensor):
    ndims = val.shape.ndims
    if ndims is not None:
      return ndims == 0
    else:
      return math_ops.equal(array_ops.rank(val), 0)
  else:
    return np.isscalar(val)


def _has_docstring(f):
  return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and
          f.__doc__)


def _add_blank_line(s):
  if s.endswith('\n'):
    return s + '\n'
  else:
    return s + '\n\n'


def _np_signature(f):
  """An enhanced inspect.signature that can handle numpy.ufunc."""
  # TODO(wangpeng): consider migrating away from inspect.signature.
  # inspect.signature has been available since Python 3.3.
  if not hasattr(inspect, 'signature'):
    return None
  if f is None:
    return None
  if not isinstance(f, np.ufunc):
    try:
      return inspect.signature(f)
    except ValueError:
      return None

  def names_from_num(prefix, n):
    if n <= 0:
      return []
    elif n == 1:
      return [prefix]
    else:
      return [prefix + str(i + 1) for i in range(n)]

  input_names = names_from_num('x', f.nin)
  output_names = names_from_num('out', f.nout)
  keyword_only_params = [('where', True), ('casting', 'same_kind'),
                         ('order', 'K'), ('dtype', None), ('subok', True),
                         ('signature', None), ('extobj', None)]
  params = []
  params += [
      inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
      for name in input_names
  ]
  if f.nout > 1:
    params += [
        inspect.Parameter(
            name, inspect.Parameter.POSITIONAL_ONLY, default=None)
        for name in output_names
    ]
  params += [
      inspect.Parameter(
          'out',
          inspect.Parameter.POSITIONAL_OR_KEYWORD,
          default=None if f.nout == 1 else (None,) * f.nout)
  ]
  params += [
      inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
      for name, default in keyword_only_params
  ]
  return inspect.Signature(params)

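# Illustrative sketch (commented out): `inspect.signature` does not work for
# ufuncs such as np.add, so `_np_signature` builds a signature manually,
# roughly:
#
#   _np_signature(np.add)
#   # -> (x1, x2, /, out=None, *, where=True, casting='same_kind', order='K',
#   #     dtype=None, subok=True, signature=None, extobj=None)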

# Python 2 doesn't allow keyword-only arguments. Python prior to 3.8 doesn't
# allow positional-only arguments. So we conflate positional-only,
# keyword-only and positional-or-keyword arguments here.
def _is_compatible_param_kind(a, b):

  def relax(k):
    if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY):
      return inspect.Parameter.POSITIONAL_OR_KEYWORD
    return k

  return relax(a) == relax(b)


def _prepare_np_fun_name_and_fun(np_fun_name, np_fun):
  """Mutually propagates information between `np_fun_name` and `np_fun`.

  If one is None and the other is not, the missing one is filled in on a
  best-effort basis: `np_fun` is looked up from `np_fun_name`, or the name is
  taken from `np_fun`.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: the numpy function whose docstring will be used.

  Returns:
    Processed `np_fun_name` and `np_fun`.
  """
  if np_fun_name is not None:
    assert isinstance(np_fun_name, str)
  if np_fun is not None:
    assert not isinstance(np_fun, str)
  if np_fun is None:
    assert np_fun_name is not None
    try:
      np_fun = getattr(np, str(np_fun_name))
    except AttributeError:
      np_fun = None
  if np_fun_name is None:
    assert np_fun is not None
    np_fun_name = np_fun.__name__
  return np_fun_name, np_fun


def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None,
                   link=None):
  """Helper to get docs."""
  assert np_f or np_fun_name
  if not np_fun_name:
    np_fun_name = np_f.__name__
  doc = 'TensorFlow variant of NumPy\'s `%s`.\n\n' % np_fun_name
  if unsupported_params:
    doc += 'Unsupported arguments: ' + ', '.join(
        '`' + name + '`' for name in unsupported_params) + '.\n\n'
  if _has_docstring(f):
    doc += f.__doc__
    doc = _add_blank_line(doc)
  # TODO(wangpeng): Re-enable the following and choose inlined vs. link to numpy
  #   doc according to some global switch.
  doc = _add_np_doc(doc, np_fun_name, np_f, link=link)
  return doc


_np_doc_form = os.getenv('TF_NP_DOC_FORM', '1.16')


def get_np_doc_form():
  """Gets the form of the original numpy docstrings.

  Returns:
    See `set_np_doc_form` for the list of valid values.
  """
  return _np_doc_form


def set_np_doc_form(value):
  r"""Selects the form of the original numpy docstrings.

  This function sets a global variable that controls how a tf-numpy symbol's
  docstring should refer to the original numpy docstring. If `value` is
  `'inlined'`, the numpy docstring will be verbatim copied into the tf-numpy
  docstring. Otherwise, a link to the original numpy docstring will be
  added. Which numpy version the link points to depends on `value`:
  * `'stable'`: the current stable version;
  * `'dev'`: the current development version;
  * pattern `\d+(\.\d+(\.\d+)?)?`: `value` will be treated as a version number,
    e.g. '1.16'.

  Args:
    value: the value to set the global variable to.
  """
  global _np_doc_form
  _np_doc_form = value

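# Illustrative usage (commented out; the values mirror those documented in
# `set_np_doc_form`):
#
#   set_np_doc_form('stable')   # link to the current stable numpy docs
#   set_np_doc_form('1.16')     # link to the numpy 1.16 docs
#   set_np_doc_form('inlined')  # copy the numpy docstring verbatim
#   get_np_doc_form()           # -> 'inlined'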

class Link:

  def __init__(self, v):
    self.value = v


class AliasOf:

  def __init__(self, v):
    self.value = v


class NoLink:
  pass


def generate_link(flag, np_fun_name):
  """Generates link from numpy function name.

  Args:
    flag: the flag to control link form. See `set_np_doc_form`.
    np_fun_name: the numpy function name.

  Returns:
    A string.
  """
  # Only add a link for the flag values handled below; otherwise return None.
  if flag == 'dev':
    template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html'
  elif flag == 'stable':
    template = (
        'https://numpy.org/doc/stable/reference/generated/numpy.%s.html')
  elif re.match(r'\d+(\.\d+(\.\d+)?)?$', flag):
    # `flag` is the version number
    template = ('https://numpy.org/doc/' + flag +
                '/reference/generated/numpy.%s.html')
  else:
    return None
  return template % np_fun_name

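# Illustrative outputs (commented out; the URLs follow the templates above):
#
#   generate_link('stable', 'sum')
#   # -> 'https://numpy.org/doc/stable/reference/generated/numpy.sum.html'
#   generate_link('1.16', 'sum')
#   # -> 'https://numpy.org/doc/1.16/reference/generated/numpy.sum.html'
#   generate_link('inlined', 'sum')  # -> None (not a recognized link form)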

_is_check_link = (os.getenv('TF_NP_CHECK_LINK', 'False') in
                  ('True', 'true', '1'))


def is_check_link():
  return _is_check_link


def set_check_link(value):
  global _is_check_link
  _is_check_link = value


def _add_np_doc(doc, np_fun_name, np_f, link):
  """Appends the numpy docstring to `doc`, according to `set_np_doc_form`.

  See `set_np_doc_form` for how it controls the form of the numpy docstring.

  Args:
    doc: the docstring to be appended to.
    np_fun_name: the name of the numpy function.
    np_f: (optional) the numpy function.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    `doc` with numpy docstring appended.
  """
  flag = get_np_doc_form()
  if flag == 'inlined':
    if _has_docstring(np_f):
      doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
      # TODO(wangpeng): It looks like code snippets in numpy doc don't work
      # correctly with doctest. Fix that and remove the reformatting of the
      # np_f docstring below.
      doc += np_f.__doc__.replace('>>>', '>')
  elif isinstance(flag, str):
    if link is None:
      url = generate_link(flag, np_fun_name)
    elif isinstance(link, AliasOf):
      url = generate_link(flag, link.value)
    elif isinstance(link, Link):
      url = link.value
    else:
      url = None
    if url is not None:
      if is_check_link():
        # Imports locally because some builds may not have `requests`
        import requests  # pylint: disable=g-import-not-at-top
        r = requests.head(url)
        if r.status_code != 200:
          raise ValueError(
              f'Check link failed at [{url}] with status code {r.status_code}. '
              f'Argument `np_fun_name` is {np_fun_name}.')
      doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % (
          np_fun_name, url)
  return doc


_is_sig_mismatch_an_error = (
    os.getenv('TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1'))


def is_sig_mismatch_an_error():
  return _is_sig_mismatch_an_error


def set_is_sig_mismatch_an_error(value):
  global _is_sig_mismatch_an_error
  _is_sig_mismatch_an_error = value


def np_doc(np_fun_name, np_fun=None, export=True, unsupported_params=None,
           link=None):
  """Attachs numpy docstring to a function.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    unsupported_params: (optional) the list of parameters not supported
      by tf.numpy.
    link: (optional) which link to use. If `None`, a default link generated from
      `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will
      be used in place of `np_fun_name` for the link generation. If an instance
      of `Link`, `link.value` will be used as the whole link. If an instance of
      `NoLink`, no link will be added.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name_orig, np_fun_orig = np_fun_name, np_fun
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)
  np_sig = _np_signature(np_fun)
  if unsupported_params is None:
    unsupported_params = []

  def decorator(f):
    """The decorator."""
    if hasattr(inspect, 'signature') and np_sig is not None:
      try:
        sig = inspect.signature(f)
      except ValueError:
        sig = None
      if sig is not None:
        for name, param in sig.parameters.items():
          np_param = np_sig.parameters.get(name)
          if np_param is None:
            if is_sig_mismatch_an_error():
              raise TypeError(
                  f'Cannot find parameter {name} in the numpy function\'s '
                  f'signature (which has these parameters: '
                  f'{list(np_sig.parameters.keys())}). Argument `np_fun_name` '
                  f'is {np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
            else:
              continue
          if (is_sig_mismatch_an_error() and
              not _is_compatible_param_kind(param.kind, np_param.kind)):
            raise TypeError(
                f'Parameter {name} is of kind {param.kind} while in numpy it '
                f'is of kind {np_param.kind}. Argument `np_fun_name` is '
                f'{np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
          has_default = (param.default != inspect.Parameter.empty)
          np_has_default = (np_param.default != inspect.Parameter.empty)
          if is_sig_mismatch_an_error() and has_default != np_has_default:
            raise TypeError(
                'Parameter {} should{} have a default value. Argument '
                '`np_fun_name` is {}. Argument `np_fun` is {}.'.format(
                    name, '' if np_has_default else ' not', np_fun_name_orig,
                    np_fun_orig))
        for name in np_sig.parameters:
          if name not in sig.parameters:
            unsupported_params.append(name)
    f.__doc__ = _np_doc_helper(
        f, np_fun, np_fun_name=np_fun_name,
        unsupported_params=unsupported_params, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    else:
      return f

  return decorator

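# Illustrative usage of `np_doc` (commented out; `my_sum` is a hypothetical
# tf-numpy implementation, not part of this module):
#
#   @np_doc('sum', export=False)
#   def my_sum(a, axis=None, dtype=None, keepdims=None):
#     ...
#
# After decoration, `my_sum.__doc__` starts with "TensorFlow variant of
# NumPy's `sum`.", lists the numpy parameters missing from `my_sum` (e.g.
# `out`) as unsupported, and, depending on `set_np_doc_form`, either inlines
# numpy's docstring or links to it.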

def np_doc_only(np_fun_name, np_fun=None, export=True):
  """Attachs numpy docstring to a function.

  This differs from np_doc in that it doesn't check for a match in signature.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must
      be a function directly under the `numpy` module, not under any submodule
      of `numpy` (e.g. `numpy.random`).

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)

  def decorator(f):
    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name)
    if export:
      return np_export.np_export(np_fun_name)(f)
    else:
      return f

  return decorator


# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
@np_doc('finfo')
def finfo(dtype):
  """Note that currently it just forwards to the numpy namesake, while
  tensorflow and numpy dtypes may have different properties."""
  return np.finfo(_to_numpy_type(dtype))
# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args


def _maybe_get_dtype(x):
  """Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
  # Don't put np.ndarray in this list, because np.result_type looks at the
  # value (not just dtype) of np.ndarray to decide the result type.
  if isinstance(x, numbers.Real):
    return x
  if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)):
    return _to_numpy_type(x.dtype)
  if isinstance(x, dtypes.DType):
    return x.as_numpy_dtype
  if isinstance(x, (list, tuple)):
    raise ValueError(
        f'Cannot find dtype for type inference from argument `x` of a sequence '
        f'type {type(x)}. For sequences, please call this function on each '
        f'element individually.')
  return x


# Can't use np_doc because np.result_type is a builtin function.
@np_doc_only('result_type')
def result_type(*arrays_and_dtypes):  # pylint: disable=missing-function-docstring
  arrays_and_dtypes = [
      _maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes)
  ]
  if not arrays_and_dtypes:
    # If arrays_and_dtypes is an empty list, let numpy decide what the dtype is.
    arrays_and_dtypes = [np.asarray([])]
  return np_dtypes._result_type(*arrays_and_dtypes)  # pylint: disable=protected-access

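# Illustrative promotions (commented out; exact results depend on the
# configured tf-numpy dtype behavior, e.g. whether float64 is allowed):
#
#   result_type(np.int8, np.int32)  # -> int32
#   result_type(np.float32, 3)      # -> float32 (Python scalars are "weak")
#   result_type()                   # falls back to np.asarray([]).dtype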

def result_type_unary(a, dtype):  # pylint: disable=missing-function-docstring
  """Find the result type from a single input and a dtype."""
  if dtype:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
    return result_type(dtype)

  # np_utils.result_type treats string inputs as dtype strings rather than as
  # string values, but for unary ops we want to treat them as string inputs.
  if isinstance(a, str):
    return np.unicode_
  elif isinstance(a, bytes):
    return np.bytes_

  # TF and numpy have different interpretations of Python types such as
  # `float`, so we let `np_utils.result_type` decide.
  return result_type(a)


def _result_type_binary(t1, t2):  # pylint: disable=missing-function-docstring
  """A specialization of result_type for 2 arguments for performance reasons."""
  try:
    return np_dtypes._result_type(_maybe_get_dtype(t1),  # pylint: disable=protected-access
                                  _maybe_get_dtype(t2))  # pylint: disable=protected-access
  except ValueError:
    return result_type(t1, t2)


@np_doc('promote_types')
def promote_types(type1, type2):  # pylint: disable=missing-function-docstring
  type1 = _to_numpy_type(type1)
  type2 = _to_numpy_type(type2)
  return np_dtypes.canonicalize_dtype(np.promote_types(type1, type2))


def tf_broadcast(*args):
  """Broadcast tensors.

  Args:
    *args: a list of tensors whose shapes are broadcastable against each other.

  Returns:
    A list of tensors broadcast to the common shape, or `args` unchanged if
    fewer than two arguments are given.
  """
  if len(args) <= 1:
    return args
  sh = array_ops.shape(args[0])
  for arg in args[1:]:
    sh = array_ops.broadcast_dynamic_shape(sh, array_ops.shape(arg))
  return [array_ops.broadcast_to(arg, sh) for arg in args]

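# Illustrative sketch (commented out; array_ops.ones is used just to get
# tensors of known shapes):
#
#   a = array_ops.ones([3, 1])
#   b = array_ops.ones([1, 4])
#   x, y = tf_broadcast(a, b)  # both broadcast to shape [3, 4]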

# TODO(wangpeng): Move the following functions to a separate file and check for
#   float dtypes in each of them.


def get_static_value(x):
  """A version of tf.get_static_value that returns None on float dtypes.

  It returns None on floating-point and complex dtypes in order to avoid
  breaking gradients.

  Args:
    x: a tensor.

  Returns:
    Same as `tf.get_static_value`, except that it returns None when `x` has a
    floating-point or complex dtype.
  """
  if isinstance(x, core.Tensor) and (x.dtype.is_floating or x.dtype.is_complex):
    return None
  return tensor_util.constant_value(x)


def _maybe_static(x):
  value = get_static_value(x)
  if value is None:
    return x
  else:
    return value

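# Illustrative behavior (commented out; constant_op.constant is tf.constant
# and is not imported in this module):
#
#   get_static_value(constant_op.constant([1, 2]))    # -> np.array([1, 2])
#   get_static_value(constant_op.constant([1., 2.]))  # -> None (float dtype,
#                                                     #    to keep gradients)
#   _maybe_static(constant_op.constant(3))            # -> numpy scalar 3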

# All the following functions exist because get_static_value can't handle
# their TF counterparts.


def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition."""
  v = get_static_value(pred)
  if v is None:
    return control_flow_ops.cond(pred, true_fn, false_fn)
  if v:
    return true_fn()
  else:
    return false_fn()

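# Illustrative sketch (commented out):
#
#   cond(True, lambda: 1, lambda: 2)  # -> 1, without building a tf.cond
#
# When the predicate's static value is unknown (e.g. it depends on runtime
# data), this falls back to control_flow_ops.cond.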

def add(a, b):
  """A version of tf.add that eagerly evaluates if possible."""
  return _maybe_static(a) + _maybe_static(b)


def subtract(a, b):
  """A version of tf.subtract that eagerly evaluates if possible."""
  return _maybe_static(a) - _maybe_static(b)


def greater(a, b):
  """A version of tf.greater that eagerly evaluates if possible."""
  return _maybe_static(a) > _maybe_static(b)


def greater_equal(a, b):
  """A version of tf.greater_equal that eagerly evaluates if possible."""
  return _maybe_static(a) >= _maybe_static(b)


def less_equal(a, b):
  """A version of tf.less_equal that eagerly evaluates if possible."""
  return _maybe_static(a) <= _maybe_static(b)


def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible."""
  a_value = get_static_value(a)
  if a_value is not None:
    if np.isscalar(a_value):
      if a_value:
        return _maybe_static(b)
      else:
        return a_value
    else:
      return a_value & _maybe_static(b)
  else:
    return a & _maybe_static(b)


def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible."""
  a_value = get_static_value(a)
  if a_value is not None:
    if np.isscalar(a_value):
      if a_value:
        return a_value
      else:
        return _maybe_static(b)
    else:
      return a_value | _maybe_static(b)
  else:
    return a | _maybe_static(b)

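# Illustrative behavior (commented out; `p` is a hypothetical boolean tensor
# whose static value is unknown):
#
#   logical_and(True, p)   # -> p      (a static True short-circuits)
#   logical_and(False, p)  # -> False  (the static value is returned)
#   logical_or(True, p)    # -> True
#   logical_or(False, p)   # -> p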

def getitem(a, slice_spec):
  """A version of __getitem__ that eagerly evaluates if possible."""
  return _maybe_static(a)[slice_spec]


def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible."""
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
  else:
    return v.all(axis=axis, keepdims=keepdims)


def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible."""
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
  else:
    return v.any(axis=axis, keepdims=keepdims)


def tf_rank(t):
  r = t.shape.rank
  if r is not None:
    return r
  return array_ops.rank(t)
