# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines an input type specification for tf.function."""

import functools
import itertools
import weakref

import numpy as np
import six

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


class FunctionSpec(object):
  """Specification of how to bind arguments to a function.

  Holds the `FullArgSpec` of a Python callable together with an optional
  `input_signature` and the `tf.function` options that affect argument
  handling (`is_pure`, `experimental_follow_type_hints`, `jit_compile`), and
  uses them to canonicalize call-time `*args` / `**kwargs`.
  """

  @classmethod
  def from_function_and_signature(cls, python_function,
                                  input_signature,
                                  is_pure=False,
                                  experimental_follow_type_hints=False,
                                  jit_compile=None):
    """Creates a FunctionSpec instance given a python function and signature.

    Args:
      python_function: a function to inspect
      input_signature: a signature of the function (None, if variable)
      is_pure: if True all input arguments (including variables and constants)
        will be converted to tensors and no variable changes allowed.
      experimental_follow_type_hints: see `tf.function`
      jit_compile: see `tf.function`

    Returns:
      instance of FunctionSpec

    Raises:
      ValueError: if `input_signature` is provided and `python_function` has
        keyword-only arguments without default values.
    """
    fullargspec = tf_inspect.getfullargspec(python_function)
    if input_signature is not None:
      # A positional input_signature cannot supply keyword-only arguments, so
      # every keyword-only argument needs a default value.
      nodefault_kwonlyargs = (
          set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()))
      if nodefault_kwonlyargs:
        raise ValueError("Cannot build TF function from "
                         f"{python_function.__name__}: keyword-only arguments "
                         "must have default values when input_signature is "
                         "provided. Got keyword-only arguments without default "
                         f"values: {sorted(nodefault_kwonlyargs)}.")

    # Checks if the `fullargspec` contains self or cls as its first argument.
    is_method = tf_inspect.isanytargetmethod(python_function)

    # Treat a wrapped partial function as a special case. For all arguments that
    # were overridden with keywords in the partial:
    #   - remove the corresponding arguments,
    #   - remove the corresponding keywords.
    _, unwrapped = tf_decorator.unwrap(python_function)
    if isinstance(unwrapped, functools.partial):
      # Also consider the Python3 case with kwonlydefaults.
      if fullargspec.defaults or fullargspec.kwonlydefaults:
        new_defaults = fullargspec.defaults
        new_args = fullargspec.args
        if fullargspec.defaults:
          # To be able to canonicalize the function properly, we want to ignore
          # default values that are overridden via a partial kwarg. For example:
          #
          #   def func(a, b, c, d=5, e=7):
          #     return a, b, c, d, e
          #   p_func = tf.function(functools.partial(func, 10, e=9))
          #
          # Here we want to drop from the defaults the parameter `e`. If we
          # forwarded the call to the partial function with a default for `e`
          # we would get an error for passing two values for one parameter.
          #
          # Note that this has a limitation: we can only override parameters at
          # the end of the parameter list.
          #
          # In this case we want to end up with 3 arguments (b, c, d) and 1
          # default value (5). We do this by constructing a mask where 0 stands
          # for a value that was overridden by a partial kwarg. For arguments
          # (b, c, d, e) we would get the mask (1, 1, 1, 0).
          old_args = fullargspec.args
          # `tuple(...)` so the concatenation below works whether the argspec
          # stores defaults as a tuple or a list.
          old_defaults = tuple(fullargspec.defaults)

          no_default = object()
          num_args_without_defaults = len(old_args) - len(old_defaults)
          left_padding = tuple([no_default] * num_args_without_defaults)

          args_with_defaults = zip(old_args, left_padding + old_defaults)

          # Create a mask where 0 stands for args that had a partial kwarg
          # defined.
          non_keyword_defaults_mask = [
              0 if key in unwrapped.keywords else 1 for key in old_args
          ]
          # Keep only arguments and defaults that were not kwargs of partial.
          new_args_with_defaults = list(
              itertools.compress(args_with_defaults, non_keyword_defaults_mask))
          # Keep all args.
          new_args = [arg for arg, _ in new_args_with_defaults]
          # Keep only real default values.
          new_defaults = [
              default for _, default in new_args_with_defaults
              if default is not no_default
          ]
        fullargspec = tf_inspect.FullArgSpec(
            args=new_args,
            varargs=fullargspec.varargs,
            varkw=fullargspec.varkw,
            defaults=new_defaults,
            kwonlyargs=[],
            kwonlydefaults={},
            annotations=fullargspec.annotations)

    # Get the function's name.  Remove functools.partial wrappers if necessary.
    while isinstance(python_function, functools.partial):
      python_function = python_function.func
    name = getattr(python_function, "__name__", "f")

    return FunctionSpec(
        fullargspec,
        is_method,
        input_signature,
        is_pure=is_pure,
        jit_compile=jit_compile,
        experimental_follow_type_hints=experimental_follow_type_hints,
        name=name)

  def __init__(self,
               fullargspec,
               is_method,
               input_signature,
               is_pure=False,
               experimental_follow_type_hints=False,
               name=None,
               jit_compile=None):
    """Constructs a FunctionSpec describing a python function.

    Args:
      fullargspec: `tf_inspect.FullArgSpec` object describing the function.
      is_method: True if the function is a method.
      input_signature: a signature of the function (None, if variable)
      is_pure: if True all input arguments (including variables and constants)
        will be converted to tensors and no variable changes allowed.
      experimental_follow_type_hints: see `tf.function`.
      name: Name of the function
      jit_compile: see `tf.function`.
    """
    self._fullargspec = fullargspec
    self._is_method = is_method
    self._is_pure = is_pure
    self._jit_compile = jit_compile
    self._experimental_follow_type_hints = experimental_follow_type_hints

    # TODO(edloper): Include name when serializing for SavedModel?
    self._name = name or "f"

    if self._is_method:
      # Remove `self`: default arguments shouldn't be matched to it.
      # TODO(b/127938157): Should this error out if there is no arg to
      # be removed?
      args = fullargspec.args[1:]
    else:
      args = fullargspec.args

    # A cache mapping from argument name to index, for canonicalizing
    # arguments that are called in a keyword-like fashion.
    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
    self._arg_names = args

    # A cache mapping from arg index to default value, for canonicalization.
    # Defaults always belong to the trailing positional parameters.
    default_values = fullargspec.defaults
    offset = len(args) - len(default_values or [])
    self._arg_indices_to_default_values = {
        offset + index: default
        for index, default in enumerate(default_values or [])
    }
    self._arg_indices_no_default_values = set(range(len(args))) - set(
        self._arg_indices_to_default_values)
    if input_signature is None:
      self._input_signature = None
      # Defined in both branches so `flat_input_signature` never raises
      # AttributeError.
      self._flat_input_signature = None
    else:
      self._input_signature = tuple(input_signature)
      self._flat_input_signature = tuple(nest.flatten(input_signature,
                                                      expand_composites=True))

  @property
  def fullargspec(self):
    return self._fullargspec

  @property
  def is_method(self):
    return self._is_method

  @property
  def args_to_indices(self):
    return self._args_to_indices

  @property
  def kwargs_to_include(self):
    # NOTE(review): `_kwargs_to_include` is never assigned in this class, so
    # this property raises AttributeError unless something else sets it.
    return self._kwargs_to_include

  @property
  def input_signature(self):
    return self._input_signature

  @property
  def flat_input_signature(self):
    """Flattened (composites expanded) input signature, or None if unset."""
    return self._flat_input_signature

  @property
  def is_pure(self):
    return self._is_pure

  @property
  def jit_compile(self):
    return self._jit_compile

  @property
  def arg_names(self):
    return self._arg_names

  @property
  def vararg_name(self):
    return self._fullargspec.varargs

  @property
  def varkw_name(self):
    return self._fullargspec.varkw

  def signature_summary(self, default_values=False):
    """Returns a string summarizing this function's signature.

    Args:
      default_values: If true, then include default values in the signature.

    Returns:
      A `string`.
    """
    args = list(self._arg_names)
    if default_values:
      for (i, default) in self._arg_indices_to_default_values.items():
        args[i] += "={}".format(default)
    if self._fullargspec.kwonlyargs:
      # `kwonlydefaults` is None (rather than {}) when no keyword-only
      # argument has a default, so normalize before the membership test.
      kwonlydefaults = self._fullargspec.kwonlydefaults or {}
      args.append("*")
      for arg_name in self._fullargspec.kwonlyargs:
        args.append(arg_name)
        if default_values and arg_name in kwonlydefaults:
          args[-1] += "={}".format(kwonlydefaults[arg_name])
    return f"{self._name}({', '.join(args)})"

  def _convert_annotated_args_to_tensors(self, args, kwargs):
    """Attempts to autobox arguments annotated as tf.Tensor.

    Args:
      args: positional argument values, excluding `self` for methods.
      kwargs: keyword argument values; annotated entries are replaced in place.

    Returns:
      A pair `(args, kwargs)`. When an `input_signature` is set the inputs are
      returned unchanged, because the signature drives conversion instead.
    """
    if self.input_signature is not None:
      # The caller unpacks the result, so return the pair rather than None.
      return args, kwargs

    args = list(args)
    for i, arg in enumerate(args):
      # Match positions against `_arg_names`, which (unlike
      # `fullargspec.args`) already excludes `self` for methods; positions past
      # the named parameters fall under `*varargs`. See
      # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
      if i < len(self._arg_names):
        annotation_key = self._arg_names[i]
      else:
        annotation_key = self._fullargspec.varargs
      arg_annotation = self._fullargspec.annotations.get(annotation_key, None)

      # TODO(rahulkamat): Change to TensorLike (here ans below)
      if arg_annotation == ops.Tensor:
        args[i] = _to_tensor_or_tensor_spec(arg)

    for kw, v in kwargs.items():
      if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args:
        annotation_key = kw
      else:
        annotation_key = self._fullargspec.varkw
      kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None)
      if kwarg_annotation == ops.Tensor:
        kwargs[kw] = _to_tensor_or_tensor_spec(v)
    return tuple(args), kwargs

  def _validate_inputs(self, flat_inputs):
    """Raises an error if inputs contain illegal values."""
    for inp in flat_inputs:
      # TODO(b/183107079): Allow these once they're handled properly.
      if isinstance(inp, weakref.ref):
        raise ValueError(
            f"weakref input {inp} not supported for function {self._name}")

  def canonicalize_function_inputs(self, *args, **kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalize the inputs to the Python function using a `FunctionSpec`
    instance. In particular, we parse the varargs and kwargs that the
    original function was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs.  Missing default arguments are added.

    If this `FunctionSpec` has an input signature, then it is used to convert
    arguments to tensors; otherwise, any inputs containing numpy arrays are
    converted to tensors.

    Args:
      *args: The varargs this object was called with.
      **kwargs: The keyword args this function was called with.

    Returns:
      A tuple `(inputs, kwargs, filtered_flat_inputs)`. `inputs` holds all
      bound positional-or-keyword arguments (defaults filled in), `kwargs`
      contains only true keyword arguments, as opposed to named arguments
      called in a keyword-like fashion, and `filtered_flat_inputs` is the
      flattened (composites expanded) list of the Tensors and Variables among
      them.

    Raises:
      TypeError: If too many positional arguments are given, a keyword cannot
        be matched to a named argument when an input signature is specified,
        an argument receives two values, or required arguments are missing.
      ValueError: If the inputs do not conform to the input signature, or an
        input is a `weakref.ref`.
    """
    if self._is_pure:
      args, kwargs = _convert_variables_to_tensors(args, kwargs)
    if self._experimental_follow_type_hints:
      args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
    # Pre-calculate to reduce overhead
    arglen = len(args)
    if self._input_signature is not None:
      if arglen > len(self._input_signature):
        raise TypeError(f"{self.signature_summary()} specifies "
                        f"{len(self._input_signature)} positional arguments, "
                        f"but got {arglen}.")
      for arg in six.iterkeys(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is None:
          raise TypeError(f"{self.signature_summary()} got unexpected keyword "
                          f"argument `{arg}`.")
        if index >= len(self._input_signature):
          raise TypeError(
              f"{self.signature_summary()} got keyword argument `{arg}` that "
              "was not included in input_signature.")

    if not kwargs:
      inputs = args
      # NOTE(review): when no argument has a default this branch performs no
      # missing-argument check (unlike the keyword branch below), so too few
      # positional arguments pass through unchanged.
      if self._arg_indices_to_default_values:
        try:
          inputs += tuple(self._arg_indices_to_default_values[i]
                          for i in range(arglen, len(self._arg_names)))
        except KeyError:
          missing_args = [
              self._arg_names[i]
              for i in range(arglen, len(self._arg_names))
              if i not in self._arg_indices_to_default_values
          ]
          raise TypeError(f"{self.signature_summary()} missing required "
                          f"arguments: {', '.join(missing_args)}.")

      if self._fullargspec.kwonlydefaults:
        kwargs.update(self._fullargspec.kwonlydefaults)
    else:
      # Maps from index of arg to its corresponding value, according to `args`
      # and `kwargs`; seeded with the default values for the named args that
      # aren't in `args`.
      arg_indices_to_values = {
          index: default for index, default in six.iteritems(
              self._arg_indices_to_default_values) if index >= arglen
      }
      consumed_args = []
      missing_arg_indices = self._arg_indices_no_default_values - set(
          range(arglen))
      for arg, value in six.iteritems(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is not None:
          if index < arglen:
            raise TypeError(f"{self.signature_summary()} got two values for "
                            f"{arg!r}.")
          arg_indices_to_values[index] = value
          # These arguments in 'kwargs' might also belong to
          # positional arguments
          missing_arg_indices.discard(index)
          consumed_args.append(arg)
      for arg in consumed_args:
        # After this loop, `kwargs` will only contain keyword_only arguments,
        # and all positional_or_keyword arguments have been moved to `inputs`.
        kwargs.pop(arg)
      inputs = args + _deterministic_dict_values(arg_indices_to_values)
      # Exclude positional args with values
      if missing_arg_indices:
        missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
        if len(missing_args) == 1:
          raise TypeError(f"{self.signature_summary()} missing 1 required "
                          f"argument: {missing_args[0]}.")
        else:
          raise TypeError(f"{self.signature_summary()} missing required "
                          f"arguments: {', '.join(missing_args)}.")

      if kwargs and self._input_signature is not None:
        raise TypeError("Keyword arguments are not supported when "
                        "input_signature is provided. Signature: "
                        f"{self.signature_summary()}. Keyword arguments: "
                        f"{kwargs}.")

      if self._fullargspec.kwonlydefaults:
        for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
          kwargs.setdefault(kwarg, default)

    if self._input_signature is None:
      inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs)
      kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs)
      flat_inputs += flat_kwargs
      filtered_flat_inputs += filtered_flat_kwargs
    else:
      inputs, flat_inputs, filtered_flat_inputs = _convert_inputs_to_signature(
          inputs, self._input_signature, self._flat_input_signature)

    self._validate_inputs(flat_inputs)

    return inputs, kwargs, filtered_flat_inputs


def _to_tensor_or_tensor_spec(x):
  """Passes Tensors and TensorSpecs through; converts anything else.

  Args:
    x: an arbitrary value.

  Returns:
    `x` itself when it is an `ops.Tensor` or `tensor_spec.TensorSpec`,
    otherwise the result of `ops.convert_to_tensor(x)`.
  """
  if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)):
    return x
  return ops.convert_to_tensor(x)


def _deterministic_dict_values(dictionary):
  return tuple(dictionary[key] for key in sorted(dictionary))


def _convert_variables_to_tensors(args, kwargs):
  """Runs every argument through `_to_tensor_or_tensor_spec`.

  Despite the name, every positional and keyword value is converted, not only
  variables.

  Args:
    args: iterable of positional argument values.
    kwargs: mapping of keyword argument names to values.

  Returns:
    A pair `(args, kwargs)`: a tuple of converted positional values and a new
    dict of converted keyword values.
  """
  converted_args = tuple(map(_to_tensor_or_tensor_spec, args))
  converted_kwargs = {}
  for key, value in kwargs.items():
    converted_kwargs[key] = _to_tensor_or_tensor_spec(value)
  return converted_args, converted_kwargs


def _convert_numpy_inputs(inputs):
  """Converts numpy array inputs to tensors.

  Args:
    inputs: a (possibly nested) structure of call arguments.

  Returns:
    A triple `(inputs, flat_inputs, filtered_flat_inputs)`. `inputs` is the
    original structure, repacked only if some leaf was converted.
    `flat_inputs` is the flattened list with array-like leaves replaced by
    constant tensors. `filtered_flat_inputs` keeps only the Tensor /
    BaseResourceVariable leaves, including the newly created constants.

  Raises:
    TypeError: if an object's `__array__()` does not return an `np.ndarray`.
  """
  # Composites are expanded here not for the numpy conversion (their
  # components are assumed to already be Tensors) but because the flattened
  # inputs eventually reach ConcreteFunction()._call_flat, which requires
  # expanded composites.
  flat_inputs = nest.flatten(inputs, expand_composites=True)

  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
  # finding a way to store them directly in the cache key (currently not
  # possible since ndarrays are not hashable).
  filtered_flat_inputs = []
  converted_any = False
  for position, item in enumerate(flat_inputs):
    if isinstance(item,
                  (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
      filtered_flat_inputs.append(item)
      continue
    # The remaining guards together are equivalent to `_is_ndarray(item)`.
    if not hasattr(item, "__array__"):
      continue
    if hasattr(item, "_should_act_as_resource_variable"):
      continue
    if isinstance(item, (np.str_, type, composite_tensor.CompositeTensor)):
      continue
    array = item.__array__()
    if not isinstance(array, np.ndarray):
      raise TypeError(f"The output of __array__ must be an np.ndarray, "
                      f"got {type(array)} from {item}.")
    tensor = constant_op.constant(array)
    flat_inputs[position] = tensor
    filtered_flat_inputs.append(tensor)
    converted_any = True

  if not converted_any:
    return inputs, flat_inputs, filtered_flat_inputs
  repacked = nest.pack_sequence_as(
      structure=inputs, flat_sequence=flat_inputs, expand_composites=True)
  return repacked, flat_inputs, filtered_flat_inputs


def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
  """Converts inputs to pass into a function with an explicit signature.

  Args:
    inputs: the canonicalized positional inputs.
    input_signature: the nested input signature.
    flat_input_signature: `input_signature` flattened with composites
      expanded.

  Returns:
    A triple `(inputs, flat_inputs, filtered_flat_inputs)`; the last element
    keeps only the Tensor / BaseResourceVariable leaves of `flat_inputs`.

  Raises:
    ValueError: if the structure does not match `input_signature`, a value
      bound to a `TensorSpec` cannot be converted to a tensor, or a value is
      incompatible with its spec.
  """

  def format_error_message(inputs, input_signature):
    rendered_inputs = ",\n    ".join(str(i) for i in inputs)
    rendered_signature = ",\n    ".join(str(i) for i in input_signature)
    return (f"  inputs: (\n    {rendered_inputs})\n"
            f"  input_signature: (\n    {rendered_signature})")

  signature_length = len(input_signature)
  try:
    flatten_inputs = nest.flatten_up_to(
        input_signature,
        inputs[:signature_length],
        expand_composites=True,
        check_types=False)  # lists are convert to tuples for `tf.data`.
  except ValueError:
    raise ValueError("Structure of Python function inputs does not match "
                     "input_signature:\n"
                     f"{format_error_message(inputs, input_signature)}.")

  converted_any = False
  paired = enumerate(zip(flatten_inputs, flat_input_signature))
  for position, (value, spec) in paired:
    needs_conversion = (isinstance(spec, tensor_spec.TensorSpec) and
                        not _pywrap_utils.IsTensor(value))
    if not needs_conversion:
      continue
    try:
      flatten_inputs[position] = ops.convert_to_tensor(
          value, dtype_hint=spec.dtype)
    except ValueError:
      raise ValueError("When input_signature is provided, all inputs to "
                       "the Python function must be convertible to "
                       "tensors:\n"
                       f"{format_error_message(inputs, input_signature)}.")
    converted_any = True

  for spec, value in zip(flat_input_signature, flatten_inputs):
    if not spec.is_compatible_with(value):
      raise ValueError("Python inputs incompatible with input_signature:\n"
                       f"{format_error_message(inputs, input_signature)}.")

  # NOTE(review): repacking against `input_signature` drops any entries of
  # `inputs` beyond `len(input_signature)`, whereas they are kept when no
  # conversion happened — confirm callers do not depend on either shape.
  if converted_any:
    inputs = nest.pack_sequence_as(
        structure=input_signature,
        flat_sequence=flatten_inputs,
        expand_composites=True)

  flat_inputs = nest.flatten(inputs, expand_composites=True)
  tensor_like_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
  filtered_flat_inputs = [
      t for t in flat_inputs if isinstance(t, tensor_like_types)
  ]
  return inputs, flat_inputs, filtered_flat_inputs

