diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc index bcdee71be18..233f1ecefd3 100644 --- a/tensorflow/core/nccl/collective_communicator.cc +++ b/tensorflow/core/nccl/collective_communicator.cc @@ -86,10 +86,15 @@ void NcclCommunicator::Enqueue(std::shared_ptr col_ctx, participant->done_callback = std::move(done); } else { CancellationToken cancel_token = cancel_mgr->get_cancellation_token(); - cancel_mgr->RegisterCallback(cancel_token, [this]() { - nccl_manager_.StartAbort(errors::Cancelled("op cancelled")); - nccl_manager_.Reset(); - }); + bool already_cancelled = + !cancel_mgr->RegisterCallback(cancel_token, [this]() { + nccl_manager_.StartAbort(errors::Cancelled("op cancelled")); + nccl_manager_.Reset(); + }); + if (already_cancelled) { + done(errors::Cancelled("op cancelled")); + return; + } participant->done_callback = [cancel_mgr, cancel_token, done = std::move(done)](const Status& s) { // Do not block on deregistration since this can be invoked by diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py index bb6c04019e9..fdaf3213759 100644 --- a/tensorflow/python/kernel_tests/collective_ops_test.py +++ b/tensorflow/python/kernel_tests/collective_ops_test.py @@ -595,7 +595,6 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase): mode='eager'), device_combination)) def testOpErrorNotAbortWithCollective(self, collective_op, device, communication): - self.skipTest('b/173733368: currently it may timeout on guitar.') # Do not abort v2 collective ops even if there're active collective ops at # the time of an op error. We rely cancellation to terminate active # collective ops.