diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index d527a76a0f8..a31aafcdab1 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -739,6 +739,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
 
         VLOG(2) << "call NcclAllReduce collective_key "
                 << collective->collective_key << " participant " << p_idx
+                << " num_participants " << collective->participants.size()
                 << " sendbuff " << sendbuff << " recvbuff " << recvbuff
                 << " nccl_comm " << nccl_comm << " comm_stream " << comm_stream
                 << " cuda_stream " << cu_stream;
@@ -849,7 +850,6 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
 }
 
 void NcclManager::StartAbort(const Status& s) {
-  VLOG(1) << "NcclManager StartAbort";
   absl::flat_hash_map<string, Collective*> collectives;
   std::vector<std::unique_ptr<Communicator>> communicators;
   {
@@ -864,6 +864,9 @@ void NcclManager::StartAbort(const Status& s) {
     collectives.swap(collectives_);
     communicators.swap(communicators_);
   }
+  VLOG(2) << "Aborted NcclManager " << this << " with " << collectives.size()
+          << " collectives and " << communicators.size()
+          << " comms with status " << s;
   // collectives_ contains pending launches that haven't been dispatched to
   // kernel launch threads, so we can simply invoke the done callbacks of them.
   for (const auto& item : collectives) {
@@ -895,6 +898,12 @@ void NcclManager::StartAbort(const Status& s) {
   pending.Wait();
 }
 
+void NcclManager::Reset() {
+  mutex_lock l(mu_);
+  status_ = Status();
+  VLOG(2) << "Reset NcclManager " << this;
+}
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h
index d6906bd2e1a..b1d9dd62f94 100644
--- a/tensorflow/core/nccl/nccl_manager.h
+++ b/tensorflow/core/nccl/nccl_manager.h
@@ -195,6 +195,10 @@ class NcclManager {
   // launched with this NcclManager.
   void StartAbort(const Status& s);
 
+  // Resets a previously aborted NcclManager, making it available for future
+  // collectives.
+  void Reset();
+
  private:
   enum CollectiveType {
     kAllReduce = 1,
diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc
index ff967175091..0d0d003d63f 100644
--- a/tensorflow/core/nccl/nccl_manager_test.cc
+++ b/tensorflow/core/nccl/nccl_manager_test.cc
@@ -17,8 +17,6 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "tensorflow/core/nccl/nccl_manager.h"
-
 #include <algorithm>
 #include <random>
 #include <vector>
@@ -27,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/gpu/gpu_device.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/nccl/nccl_manager.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/unbounded_work_queue.h"
 
@@ -863,20 +862,19 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
 // environment, on a single node with multiple GPUS. So tests that rely
 // upon such simulation need to be skipped on the ROCm platform
 
-TYPED_TEST(NcclManagerTest, Abort) {
+TYPED_TEST(NcclManagerTest, AbortThenReset) {
   using NodeState = typename TestFixture::NodeState;
   using TestCase = typename TestFixture::TestCase;
-  int num_nodes = 2;
+  const int num_nodes = 2;
   std::vector<NodeState> nodes(num_nodes);
   // First do a normal all-reduce to simulate the the case when there're
   // multiple communicators.
   this->RunMultiNodeAllReduceTest(nodes, /* num_ranks_per_node */ 1);
 
-  // Use a new communicator_key, which uses a new set of ncclComm underneath.
-  string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
-  string collective_key = "allreduce";
+  const string collective_key = "allreduce";
   ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(0);
-  auto node_fn = [&](TestCase* test_case, int node) {
+  auto node_fn = [&](TestCase* test_case, int node,
+                     const string& communicator_key) {
     auto* device = this->GetDevice(/* num_ranks_per_node */ 1, node,
                                    /* local_rank */ 0);
     auto* info = device->tensorflow_gpu_device_info();
@@ -894,6 +892,8 @@ TYPED_TEST(NcclManagerTest, Abort) {
     nodes[node].nccl_manager.SignalMultiNodeReady(collective_key);
   };
 
+  // Use a new communicator_key, which uses a new set of ncclComm underneath.
+  string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
   // Do a normal all-reduce with this communicator key to initialize ncclComm.
   // This is because ncclCommInitRank waits for all ranks and is blocking.
   {
@@ -903,7 +903,9 @@ TYPED_TEST(NcclManagerTest, Abort) {
             TensorShape({2, 3}), 0.0f));
     for (int i = 0; i < num_nodes; ++i) {
       this->work_queue_->Schedule(
-          [&node_fn, &test_case, i]() { node_fn(test_case.get(), i); });
+          [&node_fn, &test_case, i, communicator_key]() {
+            node_fn(test_case.get(), i, communicator_key);
+          });
     }
     this->VerifyResults(test_case.get());
   }
@@ -914,16 +916,41 @@ TYPED_TEST(NcclManagerTest, Abort) {
       this->MakeReductionTestCase(
          /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
          TensorShape({2, 3}), 0.0f));
-  node_fn(test_case.get(), 0);
+  node_fn(test_case.get(), 0, communicator_key);
   Env::Default()->SleepForMicroseconds(1000000);
-  nodes[0].nccl_manager.StartAbort(errors::Unavailable("peer down"));
+  for (auto& node : nodes) {
+    node.nccl_manager.StartAbort(errors::Unavailable("peer down"));
+  }
   {
     mutex_lock l(test_case->mu);
     while (test_case->num_completed != 1) {
       test_case->done_cv.wait(l);
     }
   }
+
+  // Reset the aborted NcclManager and then run another all-reduce with the
+  // resetted NcclManagers.
+  for (auto& node : nodes) {
+    node.nccl_manager.Reset();
+  }
+  // Regenerate the communicator_key, because this is needed to create new
+  // communicators.
+  communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
+  {
+    std::unique_ptr<TestCase> test_case(
+        this->MakeReductionTestCase(
+            /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
+            TensorShape({2, 3}), 0.0f));
+    for (int i = 0; i < num_nodes; ++i) {
+      this->work_queue_->Schedule(
+          [&node_fn, &test_case, i, communicator_key]() {
+            node_fn(test_case.get(), i, communicator_key);
+          });
+    }
+    this->VerifyResults(test_case.get());
+  }
 }
+
 #endif
 
 }  // namespace tensorflow
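For reference, a minimal sketch (not part of the change) of the abort-then-reset pattern that the new `NcclManager::Reset()` enables, modeled on the `AbortThenReset` test above. Only `StartAbort`, `Reset`, and `GenerateCommunicatorKey` come from the code in this diff; the `AbortAndReuse` helper and the way the per-node managers are collected are hypothetical.

```cpp
// Hypothetical sketch: recover NcclManagers after a peer failure and prepare
// them for new collectives. StartAbort/Reset/GenerateCommunicatorKey are the
// APIs touched by this diff; everything else is illustrative.
#include <vector>

#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

string AbortAndReuse(std::vector<NcclManager*>& managers) {
  // Abort pending collectives on every manager, e.g. after detecting that a
  // peer went down. This fails in-flight launches and leaves each manager in
  // an error state.
  for (NcclManager* manager : managers) {
    manager->StartAbort(errors::Unavailable("peer down"));
  }
  // Once the abort has been observed everywhere, clear the sticky error
  // status so the managers accept new collectives again.
  for (NcclManager* manager : managers) {
    manager->Reset();
  }
  // The old ncclComms were torn down by StartAbort, so new multi-node
  // collectives need a freshly generated communicator_key.
  return managers[0]->GenerateCommunicatorKey();
}

}  // namespace tensorflow
```

As in the test, every node aborts before any node resets; whether that ordering is strictly required is an assumption here, not something the diff states.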