Fix NCCL cancellation registration

Return value of CancellationManager::RegisterCallback should be checked.

PiperOrigin-RevId: 344901398
Change-Id: I471be5c74f72283a5757a5ca4d4040e95eb57d51
This commit is contained in:
Ran Chen 2020-11-30 15:32:17 -08:00 committed by TensorFlower Gardener
parent 5994e13bbd
commit 4948648cbf
2 changed files with 9 additions and 5 deletions

View File

@ -86,10 +86,15 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
participant->done_callback = std::move(done);
} else {
CancellationToken cancel_token = cancel_mgr->get_cancellation_token();
cancel_mgr->RegisterCallback(cancel_token, [this]() {
nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
nccl_manager_.Reset();
});
bool already_cancelled =
!cancel_mgr->RegisterCallback(cancel_token, [this]() {
nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
nccl_manager_.Reset();
});
if (already_cancelled) {
done(errors::Cancelled("op cancelled"));
return;
}
participant->done_callback = [cancel_mgr, cancel_token,
done = std::move(done)](const Status& s) {
// Do not block on deregistration since this can be invoked by

View File

@ -595,7 +595,6 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase):
mode='eager'), device_combination))
def testOpErrorNotAbortWithCollective(self, collective_op, device,
communication):
self.skipTest('b/173733368: currently it may timeout on guitar.')
# Do not abort v2 collective ops even if there're active collective ops at
# the time of an op error. We rely cancellation to terminate active
# collective ops.