Fix NCCL cancellation registration
Return value of CancellationManager::RegisterCallback should be checked. PiperOrigin-RevId: 344901398 Change-Id: I471be5c74f72283a5757a5ca4d4040e95eb57d51
This commit is contained in:
parent
5994e13bbd
commit
4948648cbf
@ -86,10 +86,15 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
|
||||
participant->done_callback = std::move(done);
|
||||
} else {
|
||||
CancellationToken cancel_token = cancel_mgr->get_cancellation_token();
|
||||
cancel_mgr->RegisterCallback(cancel_token, [this]() {
|
||||
nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
|
||||
nccl_manager_.Reset();
|
||||
});
|
||||
bool already_cancelled =
|
||||
!cancel_mgr->RegisterCallback(cancel_token, [this]() {
|
||||
nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
|
||||
nccl_manager_.Reset();
|
||||
});
|
||||
if (already_cancelled) {
|
||||
done(errors::Cancelled("op cancelled"));
|
||||
return;
|
||||
}
|
||||
participant->done_callback = [cancel_mgr, cancel_token,
|
||||
done = std::move(done)](const Status& s) {
|
||||
// Do not block on deregistration since this can be invoked by
|
||||
|
||||
@ -595,7 +595,6 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase):
|
||||
mode='eager'), device_combination))
|
||||
def testOpErrorNotAbortWithCollective(self, collective_op, device,
|
||||
communication):
|
||||
self.skipTest('b/173733368: currently it may timeout on guitar.')
|
||||
# Do not abort v2 collective ops even if there're active collective ops at
|
||||
# the time of an op error. We rely cancellation to terminate active
|
||||
# collective ops.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user