diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 8a059cf3d35..f5c2522f142 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -218,6 +218,9 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) {
   VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
   cem_->GetParamResolver()->StartAbort(s);
   remote_access_->StartAbort(s);
+  if (cem_->GetNcclCommunicator() != nullptr) {
+    cem_->GetNcclCommunicator()->StartAbort(s);
+  }
 }
 
 void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
index 9756215c3f3..ddfb135aeef 100644
--- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
@@ -102,7 +102,6 @@ class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
   }
 
   NcclCommunicatorInterface* GetNcclCommunicator() const override {
-    LOG(FATAL) << "Unimplemented";  // Crash OK
     return nullptr;
   }
 
diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc
index ddb906037a8..e71721ce080 100644
--- a/tensorflow/core/nccl/collective_communicator.cc
+++ b/tensorflow/core/nccl/collective_communicator.cc
@@ -164,7 +164,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 }
 
 void NcclCommunicator::StartAbort(const Status& s) {
-  CHECK(false) << "not implemented yet";  // Crash ok.
+  nccl_manager_.StartAbort(s);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index bb4e7c90a06..43c9b229450 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -22,7 +22,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" #include "tensorflow/core/profiler/lib/annotated_traceme.h" #include "tensorflow/core/profiler/lib/connected_traceme.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -279,6 +281,9 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective, }); mutex_lock l(mu_); + if (!status_.ok()) { + return status_; + } if (collective->communicator_key.empty()) { // For single-node collectives, when the caller does not specify a @@ -487,6 +492,7 @@ void NcclManager::AddParticipant(std::unique_ptr participant, ncclRedOp_t reduction_op) { Collective* to_run = nullptr; DataType data_type; + Status nccl_manager_status; if (participant->input != nullptr) { data_type = participant->input->dtype(); } else { @@ -494,92 +500,100 @@ void NcclManager::AddParticipant(std::unique_ptr participant, } { mutex_lock l(mu_); - auto collective_it = collectives_.find(context.collective_key); - Collective* collective = nullptr; - if (collective_it == collectives_.end()) { - collective = - new Collective(context.collective_key, data_type, collective_type, - reduction_op, context.num_local_devices, - context.num_global_devices, context.communicator_key); - collectives_.emplace(context.collective_key, collective); - } else { - collective = collective_it->second; - } + nccl_manager_status = status_; + if (nccl_manager_status.ok()) { + auto collective_it = collectives_.find(context.collective_key); + Collective* collective = nullptr; + if (collective_it == collectives_.end()) { + collective = new Collective( + context.collective_key, data_type, collective_type, reduction_op, + context.num_local_devices, context.num_global_devices, + context.communicator_key); + collectives_.emplace(context.collective_key, collective); + } else { + collective = collective_it->second; + } - // Check `collective` is correct and consistent. 
-    if (collective->status.ok() && !collective->single_node &&
-        collective->communicator_key.empty()) {
-      collective->status = errors::Internal(
-          "Collective ", reduction_op, " is multi node with num_local_devices=",
-          collective->num_local_devices,
-          " and num_global_devices=", collective->num_global_devices,
-          " but has an empty communicator_key");
-    }
-    if (collective->status.ok() && collective->communicator_key.size() !=
-                                       context.communicator_key.size()) {
-      collective->status =
-          errors::Internal("Collective ", reduction_op,
-                           " mismatch in member communicator_key with size ",
-                           collective->communicator_key.size(),
-                           " and arg communicator_key with size ",
-                           context.communicator_key.size());
-    }
-    if (collective->status.ok() && collective->type != collective_type) {
-      collective->status = errors::Internal(
-          "Collective ", reduction_op, " previously initialized with type ",
-          collective->type, " but now got type ", collective_type);
-    }
-    if (collective->status.ok() &&
-        collective->num_global_devices != context.num_global_devices) {
-      collective->status =
-          errors::Internal("Collective ", reduction_op,
-                           " previously initialized with num_global_devices ",
-                           collective->num_global_devices, " but now got ",
-                           context.num_global_devices);
-    }
-    if (collective->status.ok() &&
-        collective->num_local_devices != context.num_local_devices) {
-      collective->status =
-          errors::Internal("Collective ", reduction_op,
-                           "previously initialized with num_local_devices ",
-                           collective->num_local_devices, " but now got ",
-                           context.num_local_devices);
-    }
-    if (collective->status.ok() &&
-        collective->participants.size() >= collective->num_local_devices) {
-      collective->status = errors::Internal(
-          "Collective ", reduction_op, " expected ",
-          collective->num_local_devices, " participants but now has ",
-          collective->participants.size(),
-          " with one more participant being added");
-    }
-    if (collective->status.ok() && collective->root_rank >= 0 &&
-        context.source_rank >= 0 &&
-        collective->root_rank != context.source_rank) {
-      collective->status = errors::Internal(
-          "Collective ", collective->collective_key, " already has root_rank ",
-          collective->root_rank, " but new participant has root_rank ",
-          context.source_rank);
-    }
-    if (collective->status.ok() &&
-        !kValidDataTypes.Contains(collective->data_type)) {
-      collective->status = errors::Internal(
-          "Collective ", collective->collective_key,
-          " expected data types compatible with NCCL but instead got ",
-          DataTypeString(collective->data_type));
-    }
+      // Check `collective` is correct and consistent.
+      if (collective->status.ok() && !collective->single_node &&
+          collective->communicator_key.empty()) {
+        collective->status = errors::Internal(
+            "Collective ", reduction_op,
+            " is multi node with num_local_devices=",
+            collective->num_local_devices,
+            " and num_global_devices=", collective->num_global_devices,
+            " but has an empty communicator_key");
+      }
+      if (collective->status.ok() && collective->communicator_key.size() !=
+                                         context.communicator_key.size()) {
+        collective->status =
+            errors::Internal("Collective ", reduction_op,
+                             " mismatch in member communicator_key with size ",
+                             collective->communicator_key.size(),
+                             " and arg communicator_key with size ",
+                             context.communicator_key.size());
+      }
+      if (collective->status.ok() && collective->type != collective_type) {
+        collective->status = errors::Internal(
+            "Collective ", reduction_op, " previously initialized with type ",
+            collective->type, " but now got type ", collective_type);
+      }
+      if (collective->status.ok() &&
+          collective->num_global_devices != context.num_global_devices) {
+        collective->status =
+            errors::Internal("Collective ", reduction_op,
+                             " previously initialized with num_global_devices ",
+                             collective->num_global_devices, " but now got ",
+                             context.num_global_devices);
+      }
+      if (collective->status.ok() &&
+          collective->num_local_devices != context.num_local_devices) {
+        collective->status =
+            errors::Internal("Collective ", reduction_op,
+                             "previously initialized with num_local_devices ",
+                             collective->num_local_devices, " but now got ",
+                             context.num_local_devices);
+      }
+      if (collective->status.ok() &&
+          collective->participants.size() >= collective->num_local_devices) {
+        collective->status = errors::Internal(
+            "Collective ", reduction_op, " expected ",
+            collective->num_local_devices, " participants but now has ",
+            collective->participants.size(),
+            " with one more participant being added");
+      }
+      if (collective->status.ok() && collective->root_rank >= 0 &&
+          context.source_rank >= 0 &&
+          collective->root_rank != context.source_rank) {
+        collective->status = errors::Internal(
+            "Collective ", collective->collective_key,
+            " already has root_rank ", collective->root_rank,
+            " but new participant has root_rank ", context.source_rank);
+      }
+      if (collective->status.ok() &&
+          !kValidDataTypes.Contains(collective->data_type)) {
+        collective->status = errors::Internal(
+            "Collective ", collective->collective_key,
+            " expected data types compatible with NCCL but instead got ",
+            DataTypeString(collective->data_type));
+      }
 
-    if (context.source_rank >= 0) {
-      collective->root_rank = context.source_rank;
-    }
-    collective->participants.emplace_back(std::move(participant));
-    ++collective->available_participants;
+      if (context.source_rank >= 0) {
+        collective->root_rank = context.source_rank;
+      }
 
-    if (CheckReady(context.collective_key, collective)) {
-      to_run = collective;
+      collective->participants.emplace_back(std::move(participant));
+      ++collective->available_participants;
+
+      if (CheckReady(context.collective_key, collective)) {
+        to_run = collective;
+      }
     }
   }
-
+  if (!nccl_manager_status.ok()) {
+    participant->done_callback(nccl_manager_status);
+    return;
+  }
 
   if (to_run != nullptr) RunCollective(to_run);
 }
@@ -834,6 +848,52 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
   }
 }
 
+void NcclManager::StartAbort(const Status& s) {
+  VLOG(1) << "NcclManager StartAbort";
+  absl::flat_hash_map<string, Collective*> collectives;
+  // After status_ is set to a non-OK one, there should be no further
+  // modifications to collectives_.
+  {
+    mutex_lock l(mu_);
+    if (!status_.ok()) {
+      LOG(WARNING)
+          << "NcclManager already aborted, ignoring subsequent StartAbort with "
+          << s;
+      return;
+    }
+    status_ = s;
+    collectives.swap(collectives_);
+  }
+  // collectives_ contains pending launches that haven't been dispatched to
+  // kernel launch threads, so we can simply invoke their done callbacks.
+  for (const auto& item : collectives) {
+    for (const std::unique_ptr<Participant>& p : item.second->participants) {
+      p->done_callback(s);
+    }
+    item.second->Unref();
+  }
+  // Abort ncclComm. Note that there could be multiple ncclComm per device, and
+  // ncclCommAbort contains CUDA calls that require device synchronization.
+  // That is, a collective on nccl_comm_0 can block ncclCommAbort(nccl_comm_1),
+  // so we need to abort all ncclComm in a concurrent fashion. This assumes that
+  // there's only one active NcclManager at a time.
+  UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
+  int num_comms = 0;
+  for (std::unique_ptr<Communicator>& communicator : communicators_) {
+    num_comms += communicator->members.size();
+  }
+  BlockingCounter pending(num_comms);
+  for (std::unique_ptr<Communicator>& communicator : communicators_) {
+    for (const CommunicatorMember& member : communicator->members) {
+      queue.Schedule([&member, &pending]() {
+        ncclCommAbort(member.nccl_comm);
+        pending.DecrementCount();
+      });
+    }
+  }
+  pending.Wait();
+}
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h
index a8bdce87081..b91ef86042a 100644
--- a/tensorflow/core/nccl/nccl_manager.h
+++ b/tensorflow/core/nccl/nccl_manager.h
@@ -202,6 +202,10 @@ class NcclManager {
   // function.
   void SignalMultiNodeReady(const string& collective_key);
 
+  // Aborts all collectives. After abortion, no further collectives can be
+  // launched with this NcclManager.
+  void StartAbort(const Status& s);
+
  private:
  enum CollectiveType {
    kAllReduce = 1,
@@ -257,6 +261,8 @@ class NcclManager {
 
   std::vector<std::unique_ptr<Communicator>> communicators_;
 
+  Status status_ TF_GUARDED_BY(mu_);
+
   TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
 };
 
diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc
index a76b0494bab..d16eefa6f72 100644
--- a/tensorflow/core/nccl/nccl_manager_test.cc
+++ b/tensorflow/core/nccl/nccl_manager_test.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/nccl/nccl_manager.h"
@@ -300,8 +302,14 @@ class NcclManagerTest : public ::testing::Test {
 
   void RunMultiNodeAllReduceTest(const int num_nodes,
                                  const int num_ranks_per_node) {
-    const int num_global_ranks = num_nodes * num_ranks_per_node;
     std::vector<NodeState> node_states(num_nodes);
+    RunMultiNodeAllReduceTest(node_states, num_ranks_per_node);
+  }
+
+  void RunMultiNodeAllReduceTest(std::vector<NodeState>& node_states,
+                                 const int num_ranks_per_node) {
+    const int num_nodes = node_states.size();
+    const int num_global_ranks = num_nodes * num_ranks_per_node;
     const string collective_key = "allreduce";
     // The NcclManagers in this test synchronize in real-time, so we need to run
     // each node's code in a separate thread.
@@ -842,6 +850,68 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
   this->VerifyError(test_case.get());
 }
 
+TYPED_TEST(NcclManagerTest, Abort) {
+  using NodeState = typename TestFixture::NodeState;
+  using TestCase = typename TestFixture::TestCase;
+  int num_nodes = 2;
+  std::vector<NodeState> nodes(num_nodes);
+  // First do a normal all-reduce to simulate the case where there are
+  // multiple communicators.
+  this->RunMultiNodeAllReduceTest(nodes, /* num_ranks_per_node */ 1);
+
+  // Use a new communicator_key, which uses a new set of ncclComm underneath.
+  string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
+  string collective_key = "allreduce";
+  ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(0);
+  auto node_fn = [&](TestCase* test_case, int node) {
+    auto* device = this->GetDevice(/* num_ranks_per_node */ 1, node,
+                                   /* local_rank */ 0);
+    auto* info = device->tensorflow_gpu_device_info();
+    auto* stream = device->tensorflow_gpu_device_info()->stream;
+    auto participant = absl::make_unique<NcclManager::Participant>(
+        device->executor(), stream, info, &test_case->ins[node],
+        &test_case->outs[node], /* global_rank */ node,
+        this->CreateDoneCallback(test_case));
+    nodes[node].nccl_manager.AddToAllReduce(
+        std::move(participant),
+        {collective_key, /* num_local_devices */ 1,
+         /* num_global_devices */ num_nodes, communicator_key,
+         /*source_rank=*/-1},
+        reduction_op);
+    nodes[node].nccl_manager.SignalMultiNodeReady(collective_key);
+  };
+
+  // Do a normal all-reduce with this communicator key to initialize ncclComm.
+  // This is because ncclCommInitRank waits for all ranks and is blocking.
+  {
+    std::unique_ptr<TestCase> test_case(
+        this->MakeReductionTestCase(
+            /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
+            TensorShape({2, 3}), 0.0f));
+    for (int i = 0; i < num_nodes; ++i) {
+      this->work_queue_->Schedule(
+          [&node_fn, &test_case, i]() { node_fn(test_case.get(), i); });
+    }
+    this->VerifyResults(test_case.get());
+  }
+
+  // A hanging all-reduce.
+  ASSERT_GT(num_nodes, 1);
+  std::unique_ptr<TestCase> test_case(
+      this->MakeReductionTestCase(
+          /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
+          TensorShape({2, 3}), 0.0f));
+  node_fn(test_case.get(), 0);
+  Env::Default()->SleepForMicroseconds(1000000);
+  nodes[0].nccl_manager.StartAbort(errors::Unavailable("peer down"));
+  {
+    mutex_lock l(test_case->mu);
+    while (test_case->num_completed != 1) {
+      test_case->done_cv.wait(l);
+    }
+  }
+}
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/python/ops/collective_ops_gpu_test.py b/tensorflow/python/ops/collective_ops_gpu_test.py
index 87758a314b2..a1ac4a60320 100644
--- a/tensorflow/python/ops/collective_ops_gpu_test.py
+++ b/tensorflow/python/ops/collective_ops_gpu_test.py
@@ -19,6 +19,8 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+import threading
+import time
 
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -27,6 +29,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.platform import test
 
@@ -301,6 +304,65 @@ class CollectiveOpGPUTest(test.TestCase):
           [1.], group_size=1, group_key=0, instance_key=0, merge_op='Add',
           final_op='Id', communication_hint='NCCL')
 
+  @test_util.run_v2_only
+  def testAbortNccl(self):
+    self._setup_context(num_gpus=2)
+
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant(1.)
+
+    # First perform a normal collective to finish resolution.
+    def collective_fn():
+      for device in ['GPU:0', 'GPU:1']:
+        with ops.device(device):
+          collective_ops.all_reduce(
+              in_tensor,
+              group_size,
+              group_key,
+              instance_key,
+              'Add',
+              'Id',
+              communication_hint='nccl')
+
+    def_function.function(collective_fn)()
+
+    # Launch a collective that hangs, and abort the collective executor after
+    # the launch.
+    def abort_fn():
+      time.sleep(2)
+      context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')
+
+    t = threading.Thread(target=abort_fn)
+    t.start()
+
+    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
+      collective_ops.all_reduce(
+          in_tensor,
+          group_size,
+          group_key,
+          instance_key,
+          'Add',
+          'Id',
+          communication_hint='nccl')
+
+    # After abortion, subsequent collectives should fail immediately.
+    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
+      collective_ops.all_reduce(
+          in_tensor,
+          group_size,
+          group_key,
+          instance_key,
+          'Add',
+          'Id',
+          communication_hint='nccl')
+
+    t.join()
+    # Reset the context in order to reset the collective executor.
+    context._reset_context()  # pylint: disable=protected-access
+    def_function.function(collective_fn)()
+
 
 if __name__ == '__main__':
   test.main()