diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 8e19c9587fa..fe48b3f6079 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -42,6 +43,14 @@ limitations under the License.
 #define VALUE_IN_DEBUG_STRING false
 
 namespace tensorflow {
+
+namespace {
+bool IsCancelled(CancellationManager* cancel_mgr) {
+  return cancel_mgr != nullptr &&
+         (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
+}
+}  // namespace
+
 /*static*/
 int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                           int64 num_chunks) {
@@ -215,14 +224,12 @@ CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
 BaseCollectiveExecutor::~BaseCollectiveExecutor() {}
 
 void BaseCollectiveExecutor::StartAbort(const Status& s) {
-  VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
   Status status;
   {
     mutex_lock l(status_mu_);
     if (!status_.ok()) {
-      LOG(WARNING)
-          << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
-          << s;
+      VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
+              << s;
       return;
     }
     status_ = StatusGroup::MakeDerived(Status(
@@ -233,6 +240,7 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) {
         "program to reset.")));
     status = status_;
   }
+  LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
   cem_->GetParamResolver()->StartAbort(status);
   remote_access_->StartAbort(status);
   if (cem_->GetNcclCommunicator() != nullptr) {
@@ -261,9 +269,14 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                           StatusCallback done) {
   // See CompleteParamsAsync() how done() and the timeout callback interacts.
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
-  auto done_safe = [this, done, is_callback_called](const Status& s) {
+  auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
     bool called = is_callback_called->exchange(true);
     if (!called) {
+      if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
+        // This is a collective error. Abort CollectiveExecutor so that this
+        // error can propagate to other workers.
+        StartAbort(s);
+      }
       done(GetStatus(s));
     }
   };
@@ -341,9 +354,15 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
   // timeout callback executes, done_safe will become a no-op and the timeout
   // callback is responsible for invoking done() at the end.
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
-  auto done_safe = [this, is_callback_called, done](const Status& s) {
+  auto done_safe = [this, is_callback_called, cancel_mgr,
+                    done](const Status& s) {
     bool called = is_callback_called->exchange(true);
     if (!called) {
+      if (!s.ok() && !IsCancelled(cancel_mgr)) {
+        // This is a collective error. Abort CollectiveExecutor so that this
+        // error can propagate to other workers.
+        StartAbort(s);
+      }
       done(GetStatus(s));
     }
   };
diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc
index 91af06cf352..e664eb90865 100644
--- a/tensorflow/core/common_runtime/ring_alg.cc
+++ b/tensorflow/core/common_runtime/ring_alg.cc
@@ -278,12 +278,17 @@ void RingAlg::StartAbort(const Status& s) {
       status_.Update(s);
     }
   }
-  // If this is the initial entry to abort mode then invoke StartAbort
-  // on the CollectiveExecutor that invoked us. That should start
-  // cancellation on all of the outstanding CollectiveRemoteAccess
-  // actions.
+  // If this is the initial entry to abort mode and it's not a cancellation,
+  // then invoke StartAbort on the CollectiveExecutor that invoked us. That
+  // should start cancellation on all of the outstanding CollectiveRemoteAccess
+  // actions. If it's a cancellation, all pending send/recv should be cancelled
+  // as well and there's then no need to abort.
   if (abort_started) {
-    col_ctx_->col_exec->StartAbort(s);
+    if (col_ctx_->op_ctx->cancellation_manager() == nullptr ||
+        (!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() &&
+         !col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) {
+      col_ctx_->col_exec->StartAbort(s);
+    }
   }
 }
 
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 9cccae51b15..357ae158ea1 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -49,9 +49,9 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
   return k;
 }
 
-class CollectiveOpKernel : public AsyncOpKernel {
+class CollectiveOpV1Kernel : public AsyncOpKernel {
  public:
-  explicit CollectiveOpKernel(OpKernelConstruction* c)
+  explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
       : AsyncOpKernel(c), name_(name()) {}
 
   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
@@ -79,29 +79,11 @@ class CollectiveOpKernel : public AsyncOpKernel {
       // don't need to block on the deregistration. Also StartAbort() may call
       // done() and DeregisterCallback may deadlock.
       c->cancellation_manager()->TryDeregisterCallback(token);
-      // Abort CollectiveExecutor so that this error can propagate to other
-      // workers.
-      if (!c->status().ok()) {
-        col_exec->StartAbort(c->status());
-      }
       done();
     };
     ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
   }
 
- protected:
-  virtual void ComputeAsyncImpl(OpKernelContext* c,
-                                CollectiveExecutor* col_exec,
-                                DoneCallback done) = 0;
-
-  string name_;
-};
-
-class CollectiveOpV1Kernel : public CollectiveOpKernel {
- public:
-  explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
-      : CollectiveOpKernel(c) {}
-
   // A string encoding instance, frame and iter to be handed off to
   // the implementation for use in generating RecvBuf keys.
   string GetCollectiveKey(OpKernelContext* c) {
@@ -140,6 +122,11 @@ class CollectiveOpV1Kernel : public CollectiveOpKernel {
   }
 
  protected:
+  virtual void ComputeAsyncImpl(OpKernelContext* c,
+                                CollectiveExecutor* col_exec,
+                                DoneCallback done) = 0;
+
+  string name_;
   CollectiveParams col_params_;
   std::vector<int32> dependencies_;
 };
@@ -470,10 +457,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
                         CollectiveBcastRecvOpKernel);
 
-class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
+class CollectiveReduceV2OpKernel : public AsyncOpKernel {
  public:
   explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
-      : CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
+      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
     OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
     string merge_op_name;
     OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
@@ -504,9 +491,14 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
             << " communication_hint " << communication_hint_;
   }
 
- protected:
-  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
-                        DoneCallback done) override {
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            name_),
+        done);
     const Tensor& input = c->input(0);
     const Tensor& group_size = c->input(1);
     const Tensor& group_key = c->input(2);
@@ -597,6 +589,7 @@
   }
 
  private:
+  string name_;
   DataType data_type_ = DT_INVALID;
   string communication_hint_;
   float timeout_seconds_ = 0;
@@ -614,10 +607,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
                             .HostMemory("instance_key"),
                         CollectiveReduceV2OpKernel);
 
-class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
+class CollectiveGatherV2OpKernel : public AsyncOpKernel {
  public:
   explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
-      : CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
+      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
     OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
     OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
     OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
@@ -627,9 +620,14 @@ class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
             << " communication_hint " << communication_hint_;
   }
 
- protected:
-  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
-                        DoneCallback done) override {
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            name_),
+        done);
     const Tensor& input = c->input(0);
     const Tensor& group_size = c->input(1);
     const Tensor& group_key = c->input(2);
@@ -728,6 +726,7 @@
   }
 
  private:
+  string name_;
   DataType data_type_ = DT_INVALID;
   string communication_hint_;
   float timeout_seconds_ = 0;
diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc
index 56e2255ae99..bcdee71be18 100644
--- a/tensorflow/core/nccl/collective_communicator.cc
+++ b/tensorflow/core/nccl/collective_communicator.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/core/nccl/collective_communicator.h"
 
+#include "tensorflow/core/framework/cancellation.h"
+
 #if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 
 #include "absl/memory/memory.h"
@@ -77,7 +79,25 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
   auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info();
   auto participant = absl::make_unique<NcclManager::Participant>(
       compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
-      col_ctx->output, col_ctx->col_params.default_rank, std::move(done));
+      col_ctx->output, col_ctx->col_params.default_rank,
+      /*done_callback=*/nullptr);
+  CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager();
+  if (cancel_mgr == nullptr) {
+    participant->done_callback = std::move(done);
+  } else {
+    CancellationToken cancel_token = cancel_mgr->get_cancellation_token();
+    cancel_mgr->RegisterCallback(cancel_token, [this]() {
+      nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
+      nccl_manager_.Reset();
+    });
+    participant->done_callback = [cancel_mgr, cancel_token,
+                                  done = std::move(done)](const Status& s) {
+      // Do not block on deregistration since this can be invoked by
+      // NcclManager::StartAbort() in the cancellation callback.
+      cancel_mgr->TryDeregisterCallback(cancel_token);
+      done(s);
+    };
+  }
   NcclManager::Context context(
       nccl_collective_key, num_local_devices, num_global_devices,
       col_params.group.runtime_details.communicator_key,
diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index a31aafcdab1..eaa34d042ce 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -875,11 +875,12 @@ void NcclManager::StartAbort(const Status& s) {
     }
     item.second->Unref();
   }
-  // Abort ncclComm. Note that there could be multiple ncclComm per device, and
-  // ncclCommAbort contains cuda calls that requires device synchronization.
-  // That is a collective on nccl_comm_0 can block ncclCommAbort(nccl_comm_1),
-  // so we need to abort all ncclComm in a concurrent fashion. This assumes that
-  // there's only one active NcclManager at a time.
+  // Abort ncclComm. Note that there could be multiple ncclComm per device,
+  // and ncclCommAbort contains cuda calls that requires device
+  // synchronization. That is a collective on nccl_comm_0 can block
+  // ncclCommAbort(nccl_comm_1), so we need to abort all ncclComm in a
+  // concurrent fashion. This assumes that there's only one active NcclManager
+  // at a time.
   UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
   int num_comms = 0;
   for (std::unique_ptr<Communicator>& communicator : communicators) {
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index 669aae49b41..fe558bcae64 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -471,7 +471,29 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
     _setup_context()
     def_function.function(collective_fn)()
 
-  def testOpErrorNotAbort(self, collective_op, device, communication):
+
+class OpCancellationTest(test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    _setup_context()
+    super().setUp()
+
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              collective_op=[
+                  combinations.NamedObject('all_reduce',
+                                           CollectiveOpsV1.all_reduce),
+                  combinations.NamedObject('all_reduce_v2',
+                                           CollectiveOpsV2.all_reduce),
+                  combinations.NamedObject('all_gather',
+                                           CollectiveOpsV1.all_gather),
+                  combinations.NamedObject('all_gather_v2',
+                                           CollectiveOpsV2.all_gather),
+              ],
+              mode='eager'), device_combination))
+  def testOpErrorNotAbortIfNoCollective(self, collective_op, device,
+                                        communication):
     # Do not abort if there's no active collective ops. There could be
     # exceptions like EOF which we expect users to catch, aborting collective
     # ops on all op errors intervenes with this workflow.
@@ -504,9 +526,20 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
       f()
     collective_fn(constant_op.constant([1.]))
 
-  def testOpErrorAbort(self, collective_op, device, communication):
-    # Abort collective ops if there're active collective ops at the time of an
-    # op error. This is due to the inability to cancel collective ops, and op
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              collective_op=[
+                  combinations.NamedObject('all_reduce',
+                                           CollectiveOpsV1.all_reduce),
+                  combinations.NamedObject('all_gather',
+                                           CollectiveOpsV1.all_gather),
+              ],
+              mode='eager'), device_combination))
+  def testOpErrorAbortWithCollective(self, collective_op, device,
+                                     communication):
+    # Abort v1 collective ops if there're active collective ops at the time of
+    # an op error. This is due to the inability to cancel collective ops, and op
     # errors may cause running collective ops to hang.
     dev0 = '/device:%s:0' % device
     group_size = 2
@@ -548,6 +581,71 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
             instance_key,
             communication_hint=communication)
 
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              collective_op=[
+                  combinations.NamedObject('all_reduce_v2',
+                                           CollectiveOpsV2.all_reduce),
+                  combinations.NamedObject('all_gather_v2',
+                                           CollectiveOpsV2.all_gather),
+              ],
+              mode='eager'), device_combination))
+  def testOpErrorNotAbortWithCollective(self, collective_op, device,
+                                        communication):
+    # Do not abort v2 collective ops even if there're active collective ops at
+    # the time of an op error. We rely on cancellation to terminate active
+    # collective ops.
+    dev0 = '/device:%s:0' % device
+    dev1 = '/device:%s:1' % device
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    @def_function.function
+    def collective_fn():
+      for device in [dev0, dev1]:
+        with ops.device(device):
+          collective_op(
+              in_tensor,
+              group_size,
+              group_key,
+              instance_key,
+              communication_hint=communication)
+
+    # Local params resolution cannot be cancelled yet, so we perform a normal
+    # collective so that the group is resolved.
+    collective_fn()
+
+    # Make the dataset sleep a while so that the collective is being executed
+    # when the EOF happens.
+    dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
+        dataset_testing.sleep(sleep_microseconds=200))
+
+    @def_function.function
+    def f():
+      # Launch a collective op that won't be able to finish to test cancellation
+      # when other ops error.
+      with ops.device(dev0):
+        ret = collective_op(
+            in_tensor,
+            group_size,
+            group_key,
+            instance_key,
+            communication_hint=communication)
+      iterator = iter(dataset)
+      next(iterator)
+      # This should raise EOF.
+      next(iterator)
+      return ret
+
+    with self.assertRaises(errors.OutOfRangeError):
+      f()
+    # Collective ops shouldn't be aborted and new collectives should be able to
+    # proceed.
+    collective_fn()
+
   @combinations.generate(
       combinations.times(