"""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
Original C++ source file: script_ops.cc
"""

import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar

def eager_py_func(input, token, Tout, is_async=False, name=None):
  r"""Eagerly executes a python function to compute func(input)->output.

  The semantics of the input, output, and attributes are the same as those
  for PyFunc.

  Args:
    input: A list of `Tensor` objects.
    token: A `string`.
    Tout: A list of `tf.DTypes`.
    is_async: An optional `bool`. Defaults to `False`. `None` is treated as
      `False` on the graph-construction path.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`. On the graph-construction
    path, if the op yields no outputs, the op object returned by
    `_apply_op_helper` is returned instead.

  Raises:
    TypeError: If `Tout` is not a list or tuple (checked on the
      graph-construction path).
  """
  _ctx = _context._context or _context.context()
  if _ctx._thread_local_data.is_eager:
    # Eager mode: try the fast-path executor before anything else.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "EagerPyFunc", name, input, "token", token, "is_async",
          is_async, "Tout", Tout)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path declined; try the Python-level fallback next.
    try:
      return eager_py_func_eager_fallback(
          input, token=token, is_async=is_async, Tout=Tout, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  # Attributes are normalized in the order token -> Tout -> is_async so the
  # first invalid argument determines which error surfaces.
  token_attr = _execute.make_str(token, "token")
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'eager_py_func' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]
  async_attr = _execute.make_bool(
      False if is_async is None else is_async, "is_async")

  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "EagerPyFunc", input=input, token=token_attr, Tout=out_types,
      is_async=async_attr, name=name)
  results = _outputs[:]
  if not results:
    # Nothing to hand back as tensors, so return the op itself.
    return _op

  if _execute.must_record_gradient():
    recorded_attrs = (
        "token", _op.get_attr("token"),
        "is_async", _op._get_attr_bool("is_async"),
        "Tin", _op.get_attr("Tin"),
        "Tout", _op.get_attr("Tout"))
    flat_inputs = _op.inputs
    _execute.record_gradient(
        "EagerPyFunc", flat_inputs, recorded_attrs, results)
  return results

# Wrap `eager_py_func` with `_ops.to_raw_op` and export the result under the
# "raw_ops.EagerPyFunc" API name.
EagerPyFunc = tf_export("raw_ops.EagerPyFunc")(_ops.to_raw_op(eager_py_func))


def eager_py_func_eager_fallback(input, token, Tout, is_async, name, ctx):
  r"""Runs the EagerPyFunc op through `_execute.execute`.

  `eager_py_func` calls this when the fast-path executor raises
  `_FallbackException`.

  Args:
    input: Inputs for the op; converted with
      `_execute.convert_to_mixed_eager_tensors`.
    token: Value for the `token` attribute (normalized with `make_str`).
    Tout: A list or tuple of output types; its length is passed to
      `_execute.execute` as the number of outputs.
    is_async: Value for the `is_async` attribute; `None` is treated as
      `False`.
    name: A name for the operation, forwarded to `_execute.execute`.
    ctx: The context object forwarded to the conversion and execute calls.

  Returns:
    The result of `_execute.execute` for the op.

  Raises:
    TypeError: If `Tout` is not a list or tuple.
  """
  # Normalize attributes in the order token -> Tout -> is_async.
  token_attr = _execute.make_str(token, "token")
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'eager_py_func' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]
  async_attr = _execute.make_bool(
      False if is_async is None else is_async, "is_async")

  # The conversion also yields the value used for the "Tin" attribute.
  in_types, converted = _execute.convert_to_mixed_eager_tensors(input, ctx)
  flat_inputs = list(converted)
  op_attrs = (
      "token", token_attr,
      "is_async", async_attr,
      "Tin", in_types,
      "Tout", out_types)

  results = _execute.execute(
      b"EagerPyFunc", len(out_types), inputs=flat_inputs, attrs=op_attrs,
      ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "EagerPyFunc", flat_inputs, op_attrs, results)
  return results


def py_func(input, token, Tout, name=None):
  r"""Invokes a python function to compute func(input)->output.

  This operation is considered stateful. For a stateless version, see
  PyFuncStateless.

  Args:
    input: A list of `Tensor` objects.
      List of Tensors that will provide input to the Op.
    token: A `string`.
      A token representing a registered python function in this address
      space.
    Tout: A list of `tf.DTypes`. Data types of the outputs from the op.
      The length of the list specifies the number of outputs.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`. On the graph-construction
    path, if the op yields no outputs, the op object returned by
    `_apply_op_helper` is returned instead.

  Raises:
    TypeError: If `Tout` is not a list or tuple (checked on the
      graph-construction path).
  """
  _ctx = _context._context or _context.context()
  if _ctx._thread_local_data.is_eager:
    # Eager mode: try the fast-path executor before anything else.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "PyFunc", name, input, "token", token, "Tout", Tout)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path declined; try the Python-level fallback next.
    try:
      return py_func_eager_fallback(
          input, token=token, Tout=Tout, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  token_attr = _execute.make_str(token, "token")
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'py_func' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]

  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "PyFunc", input=input, token=token_attr, Tout=out_types, name=name)
  results = _outputs[:]
  if not results:
    # Nothing to hand back as tensors, so return the op itself.
    return _op

  if _execute.must_record_gradient():
    recorded_attrs = (
        "token", _op.get_attr("token"),
        "Tin", _op.get_attr("Tin"),
        "Tout", _op.get_attr("Tout"))
    flat_inputs = _op.inputs
    _execute.record_gradient(
        "PyFunc", flat_inputs, recorded_attrs, results)
  return results

# Wrap `py_func` with `_ops.to_raw_op` and export the result under the
# "raw_ops.PyFunc" API name.
PyFunc = tf_export("raw_ops.PyFunc")(_ops.to_raw_op(py_func))


def py_func_eager_fallback(input, token, Tout, name, ctx):
  r"""Runs the PyFunc op through `_execute.execute`.

  `py_func` calls this when the fast-path executor raises
  `_FallbackException`.

  Args:
    input: Inputs for the op; converted with
      `_execute.convert_to_mixed_eager_tensors`.
    token: Value for the `token` attribute (normalized with `make_str`).
    Tout: A list or tuple of output types; its length is passed to
      `_execute.execute` as the number of outputs.
    name: A name for the operation, forwarded to `_execute.execute`.
    ctx: The context object forwarded to the conversion and execute calls.

  Returns:
    The result of `_execute.execute` for the op.

  Raises:
    TypeError: If `Tout` is not a list or tuple.
  """
  token_attr = _execute.make_str(token, "token")
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'py_func' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]

  # The conversion also yields the value used for the "Tin" attribute.
  in_types, converted = _execute.convert_to_mixed_eager_tensors(input, ctx)
  flat_inputs = list(converted)
  op_attrs = ("token", token_attr, "Tin", in_types, "Tout", out_types)

  results = _execute.execute(
      b"PyFunc", len(out_types), inputs=flat_inputs, attrs=op_attrs,
      ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "PyFunc", flat_inputs, op_attrs, results)
  return results


def py_func_stateless(input, token, Tout, name=None):
  r"""A stateless version of PyFunc.

  Args:
    input: A list of `Tensor` objects.
    token: A `string`.
    Tout: A list of `tf.DTypes`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.

  Raises:
    TypeError: If `Tout` is not a list or tuple (checked on the
      graph-construction path).
  """
  _ctx = _context._context or _context.context()
  if _ctx._thread_local_data.is_eager:
    # Eager mode: try the fast-path executor before anything else.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "PyFuncStateless", name, input, "token", token, "Tout", Tout)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path declined; try the Python-level fallback next.
    try:
      return py_func_stateless_eager_fallback(
          input, token=token, Tout=Tout, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  token_attr = _execute.make_str(token, "token")
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'py_func_stateless' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]

  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "PyFuncStateless", input=input, token=token_attr, Tout=out_types,
      name=name)
  # Unlike py_func, an empty output list is returned as-is here rather than
  # being replaced by the op object.
  results = _outputs[:]

  if _execute.must_record_gradient():
    recorded_attrs = (
        "token", _op.get_attr("token"),
        "Tin", _op.get_attr("Tin"),
        "Tout", _op.get_attr("Tout"))
    flat_inputs = _op.inputs
    _execute.record_gradient(
        "PyFuncStateless", flat_inputs, recorded_attrs, results)
  return results

# Wrap `py_func_stateless` with `_ops.to_raw_op` and export the result under
# the "raw_ops.PyFuncStateless" API name.
PyFuncStateless = tf_export("raw_ops.PyFuncStateless")(_ops.to_raw_op(py_func_stateless))


def py_func_stateless_eager_fallback(input, token, Tout, name, ctx):
  r"""Runs the PyFuncStateless op through `_execute.execute`.

  `py_func_stateless` calls this when the fast-path executor raises
  `_FallbackException`.

  Args:
    input: Inputs for the op; converted with
      `_execute.convert_to_mixed_eager_tensors`.
    token: Value for the `token` attribute (normalized with `make_str`).
    Tout: A list or tuple of output types; its length is passed to
      `_execute.execute` as the number of outputs.
    name: A name for the operation, forwarded to `_execute.execute`.
    ctx: The context object forwarded to the conversion and execute calls.

  Returns:
    The result of `_execute.execute` for the op.

  Raises:
    TypeError: If `Tout` is not a list or tuple.
  """
  token_attr = _execute.make_str(token, "token")
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'py_func_stateless' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]

  # The conversion also yields the value used for the "Tin" attribute.
  in_types, converted = _execute.convert_to_mixed_eager_tensors(input, ctx)
  flat_inputs = list(converted)
  op_attrs = ("token", token_attr, "Tin", in_types, "Tout", out_types)

  results = _execute.execute(
      b"PyFuncStateless", len(out_types), inputs=flat_inputs,
      attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "PyFuncStateless", flat_inputs, op_attrs, results)
  return results

