diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index fae5eb24c52..6d858a8295e 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -51,10 +51,9 @@ py_library(
         "//tensorflow/python:tensor_util",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/eager:executor",
+        "//tensorflow/python/util",
         "//tensorflow/python/util:tf_export",
         "//tensorflow/tools/docs:doc_controls",
-        "@enum34_archive//:enum",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index 37a440bf46e..bdc90ae841b 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -31,6 +31,7 @@ from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_worker_util
@@ -771,6 +772,30 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         destinations=destinations,
         options=self._communication_options.merge(options))
 
+  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
+    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
+    # This implementation avoids using `merge_call` and just launches collective
+    # ops in one replica.
+    if options is None:
+      options = collective_util.Options()
+
+    if context.executing_eagerly():
+      # In eager mode, falls back to the default implemenation that uses
+      # `merge_call`. Replica functions are running sequentially in eager mode,
+      # and due to the blocking nature of collective ops, execution will hang if
+      # collective ops are to be launched sequentially.
+      return super()._replica_ctx_all_reduce(reduce_op, value, options)
+
+    replica_context = ds_context.get_replica_context()
+    assert replica_context, (
+        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
+        "replica context")
+    return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
+        reduce_op,
+        value,
+        replica_context._replica_id,  # pylint: disable=protected-access
+        options)
+
   def _check_health(self):
     while True:
       if self._check_health_thread_should_stop.is_set():
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 25fd0c94eae..f8f02ae94b7 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 from tensorflow.tools.docs import doc_controls
 
@@ -522,6 +523,37 @@ class CrossDeviceOps(object):
     """
     return simple_broadcast(tensor, destinations, always_mirrored=True)
 
+  # ========================== Collective APIs ================================
+  #
+  # Different than `reduce`, `batch_reduce` and `broadcast` which must be called
+  # in cross-replcia context, collective APIs are to be called in replica
+  # context.
+
+  def _all_reduce(self, reduce_op, value, replica_id, options):
+    """All-reduce the `value` across all replicas so that all get the result.
+
+    `value` can be a nested structure of tensors. The implementation should
+    generally batch the all-reduces when possible. `options` can be set to
+    hint the batching behavior.
+
+    This API must be called in a replica context.
+
+    Args:
+      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
+        be combined. Allows using string representation of the enum such as
+        "SUM", "MEAN".
+      value: Value to be reduced. A tensor or a nested structure of tensors.
+      replica_id: An interger indicating the id of the replica where this
+        all_reduce is called under. This is the local replica id that ranges
+        from 0 to len(local_devices) - 1.
+      options: A `tf.distribute.experimental.CommunicationOptions`.
+
+    Returns:
+      A tensor or a nested strucutre of tensors with the reduced values. The
+      structure is the same as `value`.
+    """
+    raise NotImplementedError("_all_reduce must be implemented in descendants.")
+
 
 @tf_export("distribute.ReductionToOneDevice")
 class ReductionToOneDevice(CrossDeviceOps):
@@ -850,6 +882,8 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
     destinations = dense_values[0]._devices  # pylint: disable=protected-access
     grouped = _group_value_by_device(dense_values)
 
+    # device_grad_packs:
+    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
     device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)
 
     # The actual aggregation of the repacked gradients. Note that they are
@@ -1044,6 +1078,55 @@ class CollectiveAllReduce(CrossDeviceOps):
     # Currently we only support equal number of devices on each worker.
     return self._group_size / len(self._devices)
 
+  def _all_reduce(self, reduce_op, value, replica_id, options):
+    """Implements CrossDeviceOps.all_reduce."""
+    # TODO(b/122840926): reuse this method in _batch_all_reduce.
+    flat_values = nest.flatten(value)
+
+    if isinstance(flat_values[0], ops.IndexedSlices):
+      raise NotImplementedError("all_reduce doesn't support IndexedSlices.")
+
+    batch_size = len(flat_values)
+
+    implementation = options.implementation.value
+    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
+    # use NCCL only when batch_size > 1, hoping that there's only one batched
+    # all-reduce, which is the gradients.
+    if (self._limited_nccl and
+        options.implementation == CommunicationImplementation.NCCL and
+        batch_size == 1):
+      implementation = CommunicationImplementation.AUTO.value
+
+    # Reverse the lists so that there's better chance that values follows
+    # the order in which they are calculated (e.g. when they're gradients), so
+    # as to overlap calculation with communication. However, this may not be
+    # optimal for cases like gradients of complicated non-sequential models.
+    #
+    # Note that we reverse the list before packing so that the first pack won't
+    # be too small, since it's more likely for first few packs to have long
+    # queuing time due to concurrent intense computation.
+    #
+    # TODO(b/147393503): explore solutions for optimal ordering.
+    flat_values.reverse()
+    packs = cross_device_utils.group_by_size(flat_values,
+                                             options.bytes_per_pack)
+
+    launcher = self._launchers[replica_id]
+    if not context.executing_eagerly() and replica_id == 0:
+      logging.info(
+          "Collective all_reduce: %d all-reduces, num_devices = %d, "
+          "group_size = %d, implementation = %s, num_packs = %d", batch_size,
+          len(self._launchers), self._group_size, implementation, len(packs))
+    flat_results = launcher.batch_all_reduce(packs, implementation,
+                                             options.timeout_seconds)
+
+    if reduce_op == reduce_util.ReduceOp.MEAN:
+      for i, v in enumerate(flat_results):
+        flat_results[i] = v / self._group_size
+    flat_results.reverse()
+
+    return nest.pack_sequence_as(value, flat_results)
+
   def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                             options):
     values_util.mark_as_unsaveable()
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 47af7174fb1..d078a66ae81 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -560,6 +560,69 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
       ]
       self.batch_reduce_and_verify(inputs, expect, options)
 
+  @combinations.generate(
+      combinations.combine(
+          num_processes=[1, 2],
+          required_gpus=[0, 1, 2],
+          implementation=[
+              CommunicationImplementation.AUTO,
+              CommunicationImplementation.RING,
+              CommunicationImplementation.NCCL,
+          ],
+          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
+      ))
+  def testCollectiveAllReduce(self, num_processes, required_gpus,
+                              implementation, reduce_op):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
+    if (num_processes == 2 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
+                    "physical GPUs for every process.")
+
+    def replica_fn():
+      collective, devices, _ = self.make_collective(num_processes,
+                                                    required_gpus)
+      options = collective_util.Options(implementation=implementation)
+      group_size = num_processes * (required_gpus or 1)
+
+      @def_function.function
+      def collective_all_reduce():
+        results = []
+        for replica_id, device in enumerate(devices):
+          with ops.device(device):
+            value = constant_op.constant(1.0)
+            results.append(
+                collective._all_reduce(reduce_op, value, replica_id, options))
+        return results
+
+      got = collective_all_reduce()
+      if reduce_op == ReduceOp.SUM:
+        expect = [1.0 * group_size] * len(devices)
+      elif reduce_op == ReduceOp.MEAN:
+        expect = [1.0] * len(devices)
+      self.assertAllClose(got, expect)
+
+      @def_function.function
+      def collective_batch_all_reduce():
+        results = []
+        for replica_id, device in enumerate(devices):
+          with ops.device(device):
+            value = (constant_op.constant(1.0), constant_op.constant(2.0))
+            results.append(
+                collective._all_reduce(reduce_op, value, replica_id, options))
+        return results
+
+      got = collective_batch_all_reduce()
+      if reduce_op == ReduceOp.SUM:
+        expect = [(1.0 * group_size, 2.0 * group_size)] * len(devices)
+      elif reduce_op == ReduceOp.MEAN:
+        expect = [(1.0, 2.0)] * len(devices)
+      self.assertAllClose(got, expect)
+
+    get_global_mpr(num_processes).run(replica_fn)
+
   @combinations.generate(
       combinations.combine(
           num_processes=[1, 2],
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index f5618835ea7..68b9ee1fc79 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -2384,6 +2384,42 @@ class StrategyExtendedV2(object):
         for t, v in value_destination_pairs
     ]
 
+  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
+    """All-reduce `value` across all replicas so that all get the final result.
+
+    If `value` is a nested structure of tensors, all-reduces of these tensors
+    will be batched when possible. `options` can be set to hint the batching
+    behavior.
+
+    This API must be called in a replica context.
+
+    Args:
+      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
+        be combined. Allows using string representation of the enum such as
+        "SUM", "MEAN".
+      value: Value to be reduced. A tensor or a nested structure of tensors.
+      options: A `tf.distribute.experimental.CommunicationOptions`. Options to
+        perform collective operations. This overrides the default options if the
+        `tf.distribute.Strategy` takes one in the constructor.
+
+    Returns:
+      A tensor or a nested strucutre of tensors with the reduced values. The
+      structure is the same as `value`.
+    """
+    if options is None:
+      options = collective_util.Options()
+    replica_context = distribution_strategy_context.get_replica_context()
+    assert replica_context, (
+        "`StrategyExtended._replica_ctx_all_reduce` must be called in"
+        " a replica context")
+
+    def merge_fn(_, flat_value):
+      return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
+                                  options)
+
+    reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
+    return nest.pack_sequence_as(value, reduced)
+
   def _gather_to(self, value, destinations, axis, options=None):
     """Gather `value` across replicas along axis-th dimension to `destinations`.
 
diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py
index 6b19a744457..d70f2c8c4a6 100644
--- a/tensorflow/python/distribute/strategy_common_test.py
+++ b/tensorflow/python/distribute/strategy_common_test.py
@@ -111,6 +111,58 @@ class ReduceTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(3 * strategy.num_replicas_in_sync, x_s)
 
 
+@combinations.generate(
+    combinations.combine(
+        strategy=[
+            strategy_combinations.multi_worker_mirrored_2x1_cpu,
+            strategy_combinations.multi_worker_mirrored_2x1_gpu,
+            strategy_combinations.tpu_strategy,
+        ] + strategy_combinations.strategies_minus_tpu,
+        tf_function=[combinations.tf_function, combinations.no_tf_function],
+        mode=['eager']))
+class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
+
+  def testBasic(self, strategy, tf_function):
+    if (isinstance(strategy, tpu_strategy.TPUStrategy) and
+        tf_function is combinations.no_tf_function):
+      self.skipTest('Skip TPUStrategy + eager combination.')
+
+    @tf_function
+    def fn():
+
+      def replica_fn():
+        value = constant_op.constant(1.0)
+        reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
+        return reduced
+
+      return strategy.experimental_local_results(strategy.run(replica_fn))
+
+    got = fn()[0]
+    self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)
+
+  def testNestedInput(self, strategy, tf_function):
+    # TODO(b/122840926): enable this test once cl/122840926 is submitted.
+    self.skipTest('Enable after cl/353109164 is submitted.')
+
+    if (isinstance(strategy, tpu_strategy.TPUStrategy) and
+        tf_function is combinations.no_tf_function):
+      self.skipTest('Skip TPUStrategy + eager combination.')
+
+    @tf_function
+    def fn():
+
+      def replica_fn():
+        value = (constant_op.constant(1.0), constant_op.constant(2.0))
+        reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
+        return reduced
+
+      return strategy.experimental_local_results(strategy.run(replica_fn))
+
+    got = fn()[0]
+    self.assertEqual(got, (1.0 * strategy.num_replicas_in_sync,
+                           2.0 * strategy.num_replicas_in_sync))
+
+
 def _make_indexed_slices(values, indices, dense_shape):
   tensor = ops.IndexedSlices(
       values=constant_op.constant(values),