"""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
Original C++ source file: collective_ops.cc
"""

import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar

def collective_all_to_all_v3(input, communicator, group_assignment, timeout_seconds=0, name=None):
  r"""Mutually exchanges multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    communicator: A `Tensor` of type `resource`.
    group_assignment: A `Tensor` of type `int32`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveAllToAllV3", name, input, communicator,
        group_assignment, "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_all_to_all_v3_eager_fallback(
          input, communicator, group_assignment,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveAllToAllV3", input=input, communicator=communicator,
                                group_assignment=group_assignment,
                                timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveAllToAllV3", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveAllToAllV3 = tf_export("raw_ops.CollectiveAllToAllV3")(_ops.to_raw_op(collective_all_to_all_v3))


def collective_all_to_all_v3_eager_fallback(input, communicator, group_assignment, timeout_seconds, name, ctx):
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  communicator = _ops.convert_to_tensor(communicator, _dtypes.resource)
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  _inputs_flat = [input, communicator, group_assignment]
  _attrs = ("T", _attr_T, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveAllToAllV3", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveAllToAllV3", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

_CollectiveAssignGroupV2Output = collections.namedtuple(
    "CollectiveAssignGroupV2",
    ["group_size", "group_key"])


def collective_assign_group_v2(group_assignment, device_index, base_key, name=None):
  r"""Assign group keys based on group assignment.

  Args:
    group_assignment: A `Tensor` of type `int32`.
    device_index: A `Tensor` of type `int32`.
    base_key: A `Tensor` of type `int32`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (group_size, group_key).

    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveAssignGroupV2", name, group_assignment, device_index,
        base_key)
      _result = _CollectiveAssignGroupV2Output._make(_result)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_assign_group_v2_eager_fallback(
          group_assignment, device_index, base_key, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveAssignGroupV2", group_assignment=group_assignment,
                                   device_index=device_index,
                                   base_key=base_key, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ()
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveAssignGroupV2", _inputs_flat, _attrs, _result)
  _result = _CollectiveAssignGroupV2Output._make(_result)
  return _result

CollectiveAssignGroupV2 = tf_export("raw_ops.CollectiveAssignGroupV2")(_ops.to_raw_op(collective_assign_group_v2))


def collective_assign_group_v2_eager_fallback(group_assignment, device_index, base_key, name, ctx):
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  device_index = _ops.convert_to_tensor(device_index, _dtypes.int32)
  base_key = _ops.convert_to_tensor(base_key, _dtypes.int32)
  _inputs_flat = [group_assignment, device_index, base_key]
  _attrs = None
  _result = _execute.execute(b"CollectiveAssignGroupV2", 2,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveAssignGroupV2", _inputs_flat, _attrs, _result)
  _result = _CollectiveAssignGroupV2Output._make(_result)
  return _result


def collective_bcast_recv(T, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Receives a tensor value broadcast from another device.

  Args:
    T: A `tf.DType` from: `tf.bool, tf.float32, tf.half, tf.float64, tf.int32, tf.int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `T`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveBcastRecv", name, "T", T, "group_size", group_size,
        "group_key", group_key, "instance_key", instance_key, "shape", shape,
        "communication_hint", communication_hint, "timeout_seconds",
        timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_bcast_recv_eager_fallback(
          T=T, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  T = _execute.make_type(T, "T")
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveBcastRecv", T=T, group_size=group_size,
                               group_key=group_key, instance_key=instance_key,
                               shape=shape,
                               communication_hint=communication_hint,
                               timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastRecv", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveBcastRecv = tf_export("raw_ops.CollectiveBcastRecv")(_ops.to_raw_op(collective_bcast_recv))


def collective_bcast_recv_eager_fallback(T, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  T = _execute.make_type(T, "T")
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _inputs_flat = []
  _attrs = ("T", T, "group_size", group_size, "group_key", group_key,
  "instance_key", instance_key, "shape", shape, "communication_hint",
  communication_hint, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveBcastRecv", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastRecv", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_bcast_recv_v2(group_size, group_key, instance_key, shape, T, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Receives a tensor value broadcast from another device.

  Args:
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
    T: A `tf.DType` from: `tf.bool, tf.float32, tf.half, tf.float64, tf.int32, tf.int64`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `T`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveBcastRecvV2", name, group_size, group_key,
        instance_key, shape, "T", T, "communication_hint", communication_hint,
        "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_bcast_recv_v2_eager_fallback(
          group_size, group_key, instance_key, shape, T=T,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  T = _execute.make_type(T, "T")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveBcastRecvV2", group_size=group_size, group_key=group_key,
                                 instance_key=instance_key, shape=shape, T=T,
                                 communication_hint=communication_hint,
                                 timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "Tshape",
              _op._get_attr_type("Tshape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastRecvV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveBcastRecvV2 = tf_export("raw_ops.CollectiveBcastRecvV2")(_ops.to_raw_op(collective_bcast_recv_v2))


def collective_bcast_recv_v2_eager_fallback(group_size, group_key, instance_key, shape, T, communication_hint, timeout_seconds, name, ctx):
  T = _execute.make_type(T, "T")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_Tshape, (shape,) = _execute.args_to_matching_eager([shape], ctx, [_dtypes.int32, _dtypes.int64, ], _dtypes.int32)
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  _inputs_flat = [group_size, group_key, instance_key, shape]
  _attrs = ("T", T, "Tshape", _attr_Tshape, "communication_hint",
  communication_hint, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveBcastRecvV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastRecvV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_bcast_send(input, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Broadcasts a tensor value to one or more other devices.

  Args:
    input: A `Tensor`. Must be one of the following types: `bool`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveBcastSend", name, input, "group_size", group_size,
        "group_key", group_key, "instance_key", instance_key, "shape", shape,
        "communication_hint", communication_hint, "timeout_seconds",
        timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_bcast_send_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveBcastSend", input=input, group_size=group_size,
                               group_key=group_key, instance_key=instance_key,
                               shape=shape,
                               communication_hint=communication_hint,
                               timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastSend", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveBcastSend = tf_export("raw_ops.CollectiveBcastSend")(_ops.to_raw_op(collective_bcast_send))


def collective_bcast_send_eager_fallback(input, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bool, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T, "group_size", group_size, "group_key", group_key,
  "instance_key", instance_key, "shape", shape, "communication_hint",
  communication_hint, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveBcastSend", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastSend", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_bcast_send_v2(input, group_size, group_key, instance_key, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Broadcasts a tensor value to one or more other devices.

  Args:
    input: A `Tensor`. Must be one of the following types: `bool`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveBcastSendV2", name, input, group_size, group_key,
        instance_key, "communication_hint", communication_hint,
        "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_bcast_send_v2_eager_fallback(
          input, group_size, group_key, instance_key,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveBcastSendV2", input=input, group_size=group_size,
                                 group_key=group_key,
                                 instance_key=instance_key,
                                 communication_hint=communication_hint,
                                 timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastSendV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveBcastSendV2 = tf_export("raw_ops.CollectiveBcastSendV2")(_ops.to_raw_op(collective_bcast_send_v2))


def collective_bcast_send_v2_eager_fallback(input, group_size, group_key, instance_key, communication_hint, timeout_seconds, name, ctx):
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bool, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  _inputs_flat = [input, group_size, group_key, instance_key]
  _attrs = ("T", _attr_T, "communication_hint", communication_hint,
  "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveBcastSendV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastSendV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_gather(input, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually accumulates multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveGather", name, input, "group_size", group_size,
        "group_key", group_key, "instance_key", instance_key, "shape", shape,
        "communication_hint", communication_hint, "timeout_seconds",
        timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_gather_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveGather", input=input, group_size=group_size,
                            group_key=group_key, instance_key=instance_key,
                            shape=shape,
                            communication_hint=communication_hint,
                            timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveGather", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveGather = tf_export("raw_ops.CollectiveGather")(_ops.to_raw_op(collective_gather))


def collective_gather_eager_fallback(input, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T, "group_size", group_size, "group_key", group_key,
  "instance_key", instance_key, "shape", shape, "communication_hint",
  communication_hint, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveGather", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveGather", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_gather_v2(input, group_size, group_key, instance_key, ordering_token, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually accumulates multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveGatherV2", name, input, group_size, group_key,
        instance_key, ordering_token, "communication_hint",
        communication_hint, "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_gather_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_gather_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveGatherV2", input=input, group_size=group_size,
                              group_key=group_key, instance_key=instance_key,
                              ordering_token=ordering_token,
                              communication_hint=communication_hint,
                              timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"), "Nordering_token",
              _op._get_attr_int("Nordering_token"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveGatherV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveGatherV2 = tf_export("raw_ops.CollectiveGatherV2")(_ops.to_raw_op(collective_gather_v2))


def collective_gather_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, communication_hint, timeout_seconds, name, ctx):
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_gather_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  _inputs_flat = [input, group_size, group_key, instance_key] + list(ordering_token)
  _attrs = ("T", _attr_T, "communication_hint", communication_hint,
  "timeout_seconds", timeout_seconds, "Nordering_token",
  _attr_Nordering_token)
  _result = _execute.execute(b"CollectiveGatherV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveGatherV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_initialize_communicator(group_key, rank, group_size, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Initializes a group for collective operations.

  Args:
    group_key: A `Tensor` of type `int32`.
    rank: A `Tensor` of type `int32`.
    group_size: A `Tensor` of type `int32`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `resource`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveInitializeCommunicator", name, group_key, rank,
        group_size, "communication_hint", communication_hint,
        "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_initialize_communicator_eager_fallback(
          group_key, rank, group_size, communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveInitializeCommunicator", group_key=group_key, rank=rank,
                                            group_size=group_size,
                                            communication_hint=communication_hint,
                                            timeout_seconds=timeout_seconds,
                                            name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("communication_hint", _op.get_attr("communication_hint"),
              "timeout_seconds", _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveInitializeCommunicator", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveInitializeCommunicator = tf_export("raw_ops.CollectiveInitializeCommunicator")(_ops.to_raw_op(collective_initialize_communicator))


def collective_initialize_communicator_eager_fallback(group_key, rank, group_size, communication_hint, timeout_seconds, name, ctx):
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  rank = _ops.convert_to_tensor(rank, _dtypes.int32)
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  _inputs_flat = [group_key, rank, group_size]
  _attrs = ("communication_hint", communication_hint, "timeout_seconds",
  timeout_seconds)
  _result = _execute.execute(b"CollectiveInitializeCommunicator", 1,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveInitializeCommunicator", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_reduce(input, group_size, group_key, instance_key, merge_op, final_op, subdiv_offsets, wait_for=[], communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    merge_op: A `string` from: `"Min", "Max", "Mul", "Add"`.
    final_op: A `string` from: `"Id", "Div"`.
    subdiv_offsets: A list of `ints`.
    wait_for: An optional list of `ints`. Defaults to `[]`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveReduce", name, input, "group_size", group_size,
        "group_key", group_key, "instance_key", instance_key, "merge_op",
        merge_op, "final_op", final_op, "subdiv_offsets", subdiv_offsets,
        "wait_for", wait_for, "communication_hint", communication_hint,
        "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, merge_op=merge_op, final_op=final_op,
          subdiv_offsets=subdiv_offsets, wait_for=wait_for,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if not isinstance(subdiv_offsets, (list, tuple)):
    raise TypeError(
        "Expected list for 'subdiv_offsets' argument to "
        "'collective_reduce' Op, not %r." % subdiv_offsets)
  subdiv_offsets = [_execute.make_int(_i, "subdiv_offsets") for _i in subdiv_offsets]
  if wait_for is None:
    wait_for = []
  if not isinstance(wait_for, (list, tuple)):
    raise TypeError(
        "Expected list for 'wait_for' argument to "
        "'collective_reduce' Op, not %r." % wait_for)
  wait_for = [_execute.make_int(_i, "wait_for") for _i in wait_for]
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveReduce", input=input, group_size=group_size,
                            group_key=group_key, instance_key=instance_key,
                            merge_op=merge_op, final_op=final_op,
                            subdiv_offsets=subdiv_offsets, wait_for=wait_for,
                            communication_hint=communication_hint,
                            timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "merge_op",
              _op.get_attr("merge_op"), "final_op", _op.get_attr("final_op"),
              "subdiv_offsets", _op.get_attr("subdiv_offsets"), "wait_for",
              _op.get_attr("wait_for"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveReduce", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveReduce = tf_export("raw_ops.CollectiveReduce")(_ops.to_raw_op(collective_reduce))


def collective_reduce_eager_fallback(input, group_size, group_key, instance_key, merge_op, final_op, subdiv_offsets, wait_for, communication_hint, timeout_seconds, name, ctx):
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if not isinstance(subdiv_offsets, (list, tuple)):
    raise TypeError(
        "Expected list for 'subdiv_offsets' argument to "
        "'collective_reduce' Op, not %r." % subdiv_offsets)
  subdiv_offsets = [_execute.make_int(_i, "subdiv_offsets") for _i in subdiv_offsets]
  if wait_for is None:
    wait_for = []
  if not isinstance(wait_for, (list, tuple)):
    raise TypeError(
        "Expected list for 'wait_for' argument to "
        "'collective_reduce' Op, not %r." % wait_for)
  wait_for = [_execute.make_int(_i, "wait_for") for _i in wait_for]
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T, "group_size", group_size, "group_key", group_key,
  "instance_key", instance_key, "merge_op", merge_op, "final_op", final_op,
  "subdiv_offsets", subdiv_offsets, "wait_for", wait_for,
  "communication_hint", communication_hint, "timeout_seconds",
  timeout_seconds)
  _result = _execute.execute(b"CollectiveReduce", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveReduce", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_reduce_v2(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint="auto", timeout_seconds=0, max_subdivs_per_device=-1, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    merge_op: A `string` from: `"Min", "Max", "Mul", "Add"`.
    final_op: A `string` from: `"Id", "Div"`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    max_subdivs_per_device: An optional `int`. Defaults to `-1`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveReduceV2", name, input, group_size, group_key,
        instance_key, ordering_token, "merge_op", merge_op, "final_op",
        final_op, "communication_hint", communication_hint, "timeout_seconds",
        timeout_seconds, "max_subdivs_per_device", max_subdivs_per_device)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          merge_op=merge_op, final_op=final_op,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds,
          max_subdivs_per_device=max_subdivs_per_device, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  if max_subdivs_per_device is None:
    max_subdivs_per_device = -1
  max_subdivs_per_device = _execute.make_int(max_subdivs_per_device, "max_subdivs_per_device")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveReduceV2", input=input, group_size=group_size,
                              group_key=group_key, instance_key=instance_key,
                              ordering_token=ordering_token,
                              merge_op=merge_op, final_op=final_op,
                              communication_hint=communication_hint,
                              timeout_seconds=timeout_seconds,
                              max_subdivs_per_device=max_subdivs_per_device,
                              name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "merge_op",
              _op.get_attr("merge_op"), "final_op", _op.get_attr("final_op"),
              "communication_hint", _op.get_attr("communication_hint"),
              "timeout_seconds", _op.get_attr("timeout_seconds"),
              "Nordering_token", _op._get_attr_int("Nordering_token"),
              "max_subdivs_per_device",
              _op._get_attr_int("max_subdivs_per_device"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveReduceV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveReduceV2 = tf_export("raw_ops.CollectiveReduceV2")(_ops.to_raw_op(collective_reduce_v2))


def collective_reduce_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint, timeout_seconds, max_subdivs_per_device, name, ctx):
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  if max_subdivs_per_device is None:
    max_subdivs_per_device = -1
  max_subdivs_per_device = _execute.make_int(max_subdivs_per_device, "max_subdivs_per_device")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  _inputs_flat = [input, group_size, group_key, instance_key] + list(ordering_token)
  _attrs = ("T", _attr_T, "merge_op", merge_op, "final_op", final_op,
  "communication_hint", communication_hint, "timeout_seconds",
  timeout_seconds, "Nordering_token", _attr_Nordering_token,
  "max_subdivs_per_device", max_subdivs_per_device)
  _result = _execute.execute(b"CollectiveReduceV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveReduceV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_reduce_v3(input, communicator, group_assignment, reduction, timeout_seconds=0, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    communicator: A `Tensor` of type `resource`.
    group_assignment: A `Tensor` of type `int32`.
    reduction: A `string` from: `"Min", "Max", "Mul", "Add"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveReduceV3", name, input, communicator,
        group_assignment, "reduction", reduction, "timeout_seconds",
        timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_v3_eager_fallback(
          input, communicator, group_assignment, reduction=reduction,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  reduction = _execute.make_str(reduction, "reduction")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveReduceV3", input=input, communicator=communicator,
                              group_assignment=group_assignment,
                              reduction=reduction,
                              timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "reduction",
              _op.get_attr("reduction"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveReduceV3", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CollectiveReduceV3 = tf_export("raw_ops.CollectiveReduceV3")(_ops.to_raw_op(collective_reduce_v3))


def collective_reduce_v3_eager_fallback(input, communicator, group_assignment, reduction, timeout_seconds, name, ctx):
  reduction = _execute.make_str(reduction, "reduction")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  communicator = _ops.convert_to_tensor(communicator, _dtypes.resource)
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  _inputs_flat = [input, communicator, group_assignment]
  _attrs = ("T", _attr_T, "reduction", reduction, "timeout_seconds",
  timeout_seconds)
  _result = _execute.execute(b"CollectiveReduceV3", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveReduceV3", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

