From f2ebefba658fb4f424842f03d15f694cb917110f Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Fri, 18 Sep 2020 13:12:47 -0700 Subject: [PATCH] Add NcclManager::StartAbort This allows to abort NcclManager when the cluster is unhealthy. After the abortion, any subsequent call to NcclManager will error immediately. After calling ncclCommAbort, ongoing and subsequent nccl launches should error. Note that this cannot abort NCCL initialization yet. PiperOrigin-RevId: 332512565 Change-Id: I1cc53078f6cddeeea566ed18edbb45f7a2d833a0 --- .../base_collective_executor.cc | 3 + .../test_collective_executor_mgr.h | 1 - .../core/nccl/collective_communicator.cc | 2 +- tensorflow/core/nccl/nccl_manager.cc | 222 +++++++++++------- tensorflow/core/nccl/nccl_manager.h | 6 + tensorflow/core/nccl/nccl_manager_test.cc | 72 +++++- .../python/ops/collective_ops_gpu_test.py | 62 +++++ 7 files changed, 284 insertions(+), 84 deletions(-) diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index 8a059cf3d35..f5c2522f142 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -218,6 +218,9 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) { VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s; cem_->GetParamResolver()->StartAbort(s); remote_access_->StartAbort(s); + if (cem_->GetNcclCommunicator() != nullptr) { + cem_->GetNcclCommunicator()->StartAbort(s); + } } void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h index 9756215c3f3..ddfb135aeef 100644 --- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h @@ -102,7 +102,6 @@ class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface { } NcclCommunicatorInterface* GetNcclCommunicator() const override { - LOG(FATAL) << "Unimplemented"; // Crash OK return nullptr; } diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc index ddb906037a8..e71721ce080 100644 --- a/tensorflow/core/nccl/collective_communicator.cc +++ b/tensorflow/core/nccl/collective_communicator.cc @@ -164,7 +164,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr col_ctx, } void NcclCommunicator::StartAbort(const Status& s) { - CHECK(false) << "not implemented yet"; // Crash ok. + nccl_manager_.StartAbort(s); } } // namespace tensorflow diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc index bb4e7c90a06..43c9b229450 100644 --- a/tensorflow/core/nccl/nccl_manager.cc +++ b/tensorflow/core/nccl/nccl_manager.cc @@ -22,7 +22,9 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" #include "tensorflow/core/profiler/lib/annotated_traceme.h" #include "tensorflow/core/profiler/lib/connected_traceme.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -279,6 +281,9 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective, }); mutex_lock l(mu_); + if (!status_.ok()) { + return status_; + } if (collective->communicator_key.empty()) { // For single-node collectives, when the caller does not specify a @@ -487,6 +492,7 @@ void NcclManager::AddParticipant(std::unique_ptr participant, ncclRedOp_t reduction_op) { Collective* to_run = nullptr; DataType data_type; + Status nccl_manager_status; if (participant->input != nullptr) { data_type = participant->input->dtype(); } else { @@ -494,92 +500,100 @@ void NcclManager::AddParticipant(std::unique_ptr participant, } { mutex_lock l(mu_); - auto collective_it = collectives_.find(context.collective_key); - Collective* collective = nullptr; - if (collective_it == collectives_.end()) { - collective = - new Collective(context.collective_key, data_type, collective_type, - reduction_op, context.num_local_devices, - context.num_global_devices, context.communicator_key); - collectives_.emplace(context.collective_key, collective); - } else { - collective = collective_it->second; - } + nccl_manager_status = status_; + if (nccl_manager_status.ok()) { + auto collective_it = collectives_.find(context.collective_key); + Collective* collective = nullptr; + if (collective_it == collectives_.end()) { + collective = new Collective( + context.collective_key, data_type, collective_type, reduction_op, + context.num_local_devices, context.num_global_devices, + context.communicator_key); + collectives_.emplace(context.collective_key, collective); + } else { + collective = collective_it->second; + } - // Check `collective` is correct and consistent. - if (collective->status.ok() && !collective->single_node && - collective->communicator_key.empty()) { - collective->status = errors::Internal( - "Collective ", reduction_op, " is multi node with num_local_devices=", - collective->num_local_devices, - " and num_global_devices=", collective->num_global_devices, - " but has an empty communicator_key"); - } - if (collective->status.ok() && collective->communicator_key.size() != - context.communicator_key.size()) { - collective->status = - errors::Internal("Collective ", reduction_op, - " mismatch in member communicator_key with size ", - collective->communicator_key.size(), - " and arg communicator_key with size ", - context.communicator_key.size()); - } - if (collective->status.ok() && collective->type != collective_type) { - collective->status = errors::Internal( - "Collective ", reduction_op, " previously initialized with type ", - collective->type, " but now got type ", collective_type); - } - if (collective->status.ok() && - collective->num_global_devices != context.num_global_devices) { - collective->status = - errors::Internal("Collective ", reduction_op, - " previously initialized with num_global_devices ", - collective->num_global_devices, " but now got ", - context.num_global_devices); - } - if (collective->status.ok() && - collective->num_local_devices != context.num_local_devices) { - collective->status = - errors::Internal("Collective ", reduction_op, - "previously initialized with num_local_devices ", - collective->num_local_devices, " but now got ", - context.num_local_devices); - } - if (collective->status.ok() && - collective->participants.size() >= collective->num_local_devices) { - collective->status = errors::Internal( - "Collective ", reduction_op, " expected ", - collective->num_local_devices, " participants but now has ", - collective->participants.size(), - " with one more participant being added"); - } - if (collective->status.ok() && collective->root_rank >= 0 && - context.source_rank >= 0 && - collective->root_rank != context.source_rank) { - collective->status = errors::Internal( - "Collective ", collective->collective_key, " already has root_rank ", - collective->root_rank, " but new participant has root_rank ", - context.source_rank); - } - if (collective->status.ok() && - !kValidDataTypes.Contains(collective->data_type)) { - collective->status = errors::Internal( - "Collective ", collective->collective_key, - " expected data types compatible with NCCL but instead got ", - DataTypeString(collective->data_type)); - } + // Check `collective` is correct and consistent. + if (collective->status.ok() && !collective->single_node && + collective->communicator_key.empty()) { + collective->status = errors::Internal( + "Collective ", reduction_op, + " is multi node with num_local_devices=", + collective->num_local_devices, + " and num_global_devices=", collective->num_global_devices, + " but has an empty communicator_key"); + } + if (collective->status.ok() && collective->communicator_key.size() != + context.communicator_key.size()) { + collective->status = + errors::Internal("Collective ", reduction_op, + " mismatch in member communicator_key with size ", + collective->communicator_key.size(), + " and arg communicator_key with size ", + context.communicator_key.size()); + } + if (collective->status.ok() && collective->type != collective_type) { + collective->status = errors::Internal( + "Collective ", reduction_op, " previously initialized with type ", + collective->type, " but now got type ", collective_type); + } + if (collective->status.ok() && + collective->num_global_devices != context.num_global_devices) { + collective->status = + errors::Internal("Collective ", reduction_op, + " previously initialized with num_global_devices ", + collective->num_global_devices, " but now got ", + context.num_global_devices); + } + if (collective->status.ok() && + collective->num_local_devices != context.num_local_devices) { + collective->status = + errors::Internal("Collective ", reduction_op, + "previously initialized with num_local_devices ", + collective->num_local_devices, " but now got ", + context.num_local_devices); + } + if (collective->status.ok() && + collective->participants.size() >= collective->num_local_devices) { + collective->status = errors::Internal( + "Collective ", reduction_op, " expected ", + collective->num_local_devices, " participants but now has ", + collective->participants.size(), + " with one more participant being added"); + } + if (collective->status.ok() && collective->root_rank >= 0 && + context.source_rank >= 0 && + collective->root_rank != context.source_rank) { + collective->status = errors::Internal( + "Collective ", collective->collective_key, + " already has root_rank ", collective->root_rank, + " but new participant has root_rank ", context.source_rank); + } + if (collective->status.ok() && + !kValidDataTypes.Contains(collective->data_type)) { + collective->status = errors::Internal( + "Collective ", collective->collective_key, + " expected data types compatible with NCCL but instead got ", + DataTypeString(collective->data_type)); + } - if (context.source_rank >= 0) { - collective->root_rank = context.source_rank; - } - collective->participants.emplace_back(std::move(participant)); - ++collective->available_participants; + if (context.source_rank >= 0) { + collective->root_rank = context.source_rank; + } - if (CheckReady(context.collective_key, collective)) { - to_run = collective; + collective->participants.emplace_back(std::move(participant)); + ++collective->available_participants; + + if (CheckReady(context.collective_key, collective)) { + to_run = collective; + } } } - + if (!nccl_manager_status.ok()) { + participant->done_callback(nccl_manager_status); + return; + } if (to_run != nullptr) RunCollective(to_run); } @@ -834,6 +848,52 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { } } +void NcclManager::StartAbort(const Status& s) { + VLOG(1) << "NcclManager StartAbort"; + absl::flat_hash_map collectives; + // After status_ is set to a non-OK one, there should be no further + // modifications to collectives_. + { + mutex_lock l(mu_); + if (!status_.ok()) { + LOG(WARNING) + << "NcclManager already aborted, ignoring subsequent StartAbort with " + << s; + return; + } + status_ = s; + collectives.swap(collectives_); + } + // collectives_ contains pending launches that haven't been dispatched to + // kernel launch threads, so we can simply invoke the done callbacks of them. + for (const auto& item : collectives) { + for (const std::unique_ptr& p : item.second->participants) { + p->done_callback(s); + } + item.second->Unref(); + } + // Abort ncclComm. Note that there could be multiple ncclComm per device, and + // ncclCommAbort contains cuda calls that requires device synchronization. + // That is a collective on nccl_comm_0 can block ncclCommAbort(nccl_comm_1), + // so we need to abort all ncclComm in a concurrent fashion. This assumes that + // there's only one active NcclManager at a time. + UnboundedWorkQueue queue(Env::Default(), "nccl_abort"); + int num_comms = 0; + for (std::unique_ptr& communicator : communicators_) { + num_comms += communicator->members.size(); + } + BlockingCounter pending(num_comms); + for (std::unique_ptr& communicator : communicators_) { + for (const CommunicatorMember& member : communicator->members) { + queue.Schedule([&member, &pending]() { + ncclCommAbort(member.nccl_comm); + pending.DecrementCount(); + }); + } + } + pending.Wait(); +} + } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h index a8bdce87081..b91ef86042a 100644 --- a/tensorflow/core/nccl/nccl_manager.h +++ b/tensorflow/core/nccl/nccl_manager.h @@ -202,6 +202,10 @@ class NcclManager { // function. void SignalMultiNodeReady(const string& collective_key); + // Aborts all collectives. After abortion, no further collectives can be + // launched with this NcclManager. + void StartAbort(const Status& s); + private: enum CollectiveType { kAllReduce = 1, @@ -257,6 +261,8 @@ class NcclManager { std::vector> communicators_; + Status status_ TF_GUARDED_BY(mu_); + TF_DISALLOW_COPY_AND_ASSIGN(NcclManager); }; diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc index a76b0494bab..d16eefa6f72 100644 --- a/tensorflow/core/nccl/nccl_manager_test.cc +++ b/tensorflow/core/nccl/nccl_manager_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/nccl/nccl_manager.h" @@ -300,8 +302,14 @@ class NcclManagerTest : public ::testing::Test { void RunMultiNodeAllReduceTest(const int num_nodes, const int num_ranks_per_node) { - const int num_global_ranks = num_nodes * num_ranks_per_node; std::vector node_states(num_nodes); + RunMultiNodeAllReduceTest(node_states, num_ranks_per_node); + } + + void RunMultiNodeAllReduceTest(std::vector& node_states, + const int num_ranks_per_node) { + const int num_nodes = node_states.size(); + const int num_global_ranks = num_nodes * num_ranks_per_node; const string collective_key = "allreduce"; // The NcclManagers in this test synchronize in real-time, so we need to run // each node's code in a separate thread. @@ -842,6 +850,68 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) { this->VerifyError(test_case.get()); } +TYPED_TEST(NcclManagerTest, Abort) { + using NodeState = typename TestFixture::NodeState; + using TestCase = typename TestFixture::TestCase; + int num_nodes = 2; + std::vector nodes(num_nodes); + // First do a normal all-reduce to simulate the the case when there're + // multiple communicators. + this->RunMultiNodeAllReduceTest(nodes, /* num_ranks_per_node */ 1); + + // Use a new communicator_key, which uses a new set of ncclComm underneath. + string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey(); + string collective_key = "allreduce"; + ncclRedOp_t reduction_op = static_cast(0); + auto node_fn = [&](TestCase* test_case, int node) { + auto* device = this->GetDevice(/* num_ranks_per_node */ 1, node, + /* local_rank */ 0); + auto* info = device->tensorflow_gpu_device_info(); + auto* stream = device->tensorflow_gpu_device_info()->stream; + auto participant = absl::make_unique( + device->executor(), stream, info, &test_case->ins[node], + &test_case->outs[node], /* global_rank */ node, + this->CreateDoneCallback(test_case)); + nodes[node].nccl_manager.AddToAllReduce( + std::move(participant), + {collective_key, /* num_local_devices */ 1, + /* num_global_devices */ num_nodes, communicator_key, + /*source_rank=*/-1}, + reduction_op); + nodes[node].nccl_manager.SignalMultiNodeReady(collective_key); + }; + + // Do a normal all-reduce with this communicator key to initialize ncclComm. + // This is because ncclCommInitRank waits for all ranks and is blocking. + { + std::unique_ptr test_case( + this->MakeReductionTestCase( + /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op, + TensorShape({2, 3}), 0.0f)); + for (int i = 0; i < num_nodes; ++i) { + this->work_queue_->Schedule( + [&node_fn, &test_case, i]() { node_fn(test_case.get(), i); }); + } + this->VerifyResults(test_case.get()); + } + + // A hanging all-reduce. + ASSERT_GT(num_nodes, 1); + std::unique_ptr test_case( + this->MakeReductionTestCase( + /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op, + TensorShape({2, 3}), 0.0f)); + node_fn(test_case.get(), 0); + Env::Default()->SleepForMicroseconds(1000000); + nodes[0].nccl_manager.StartAbort(errors::Unavailable("peer down")); + { + mutex_lock l(test_case->mu); + while (test_case->num_completed != 1) { + test_case->done_cv.wait(l); + } + } +} + } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/python/ops/collective_ops_gpu_test.py b/tensorflow/python/ops/collective_ops_gpu_test.py index 87758a314b2..a1ac4a60320 100644 --- a/tensorflow/python/ops/collective_ops_gpu_test.py +++ b/tensorflow/python/ops/collective_ops_gpu_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import os +import threading +import time from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -27,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import collective_ops from tensorflow.python.platform import test @@ -301,6 +304,65 @@ class CollectiveOpGPUTest(test.TestCase): [1.], group_size=1, group_key=0, instance_key=0, merge_op='Add', final_op='Id', communication_hint='NCCL') + @test_util.run_v2_only + def testAbortNccl(self): + self._setup_context(num_gpus=2) + + group_size = 2 + group_key = 100 + instance_key = 100 + in_tensor = constant_op.constant(1.) + + # First perform a normal collective to finish resolution. + def collective_fn(): + for device in ['GPU:0', 'GPU:1']: + with ops.device(device): + collective_ops.all_reduce( + in_tensor, + group_size, + group_key, + instance_key, + 'Add', + 'Id', + communication_hint='nccl') + + def_function.function(collective_fn)() + + # Launch a collective that hangs, and abort the collective executor after + # the launch. + def abort_fn(): + time.sleep(2) + context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down') + + t = threading.Thread(target=abort_fn) + t.start() + + with self.assertRaisesRegex(errors.UnavailableError, 'peer down'): + collective_ops.all_reduce( + in_tensor, + group_size, + group_key, + instance_key, + 'Add', + 'Id', + communication_hint='nccl') + + # After abortion, subsequent collectives should fail immediately. + with self.assertRaisesRegex(errors.UnavailableError, 'peer down'): + collective_ops.all_reduce( + in_tensor, + group_size, + group_key, + instance_key, + 'Add', + 'Id', + communication_hint='nccl') + + t.join() + # Reset the context in order to reset the collective executor. + context._reset_context() # pylint: disable=protected-access + def_function.function(collective_fn)() + if __name__ == '__main__': test.main()