From 322ec054af88b612d8f2a783eda2a4eb8f16e598 Mon Sep 17 00:00:00 2001
From: Ayush Dubey
Date: Fri, 26 Jul 2019 15:57:22 -0700
Subject: [PATCH] Sort NCCL collective participants by rank in addition to stream executor.

Before this change, if participants in an NCCL collective were on the same
GPU, their ordering was non-deterministic when the `ncclAllReduce` kernels
were launched. However, `NcclManager::LoopKernelLaunches` picked the
`Participant` whose index in the list of participants matched the
`CommunicatorMember`'s index. It was therefore possible to associate a
`Participant` with a communicator of the incorrect rank.

The test added in this change, which launches broadcast kernels from
concurrent threads, tickles this bug. We did not run into it before because
the previous concurrent tests only launched all-reduce kernels, in which all
ranks play the same role.

The fix is to sort participants on the same GPU, i.e. with the same stream
executor, by their rank, thereby ensuring that the `Participant` <->
`CommunicatorMember` matching is correct at launch time.

PiperOrigin-RevId: 260230221
---
 tensorflow/core/nccl/nccl_manager.cc      |   3 +
 tensorflow/core/nccl/nccl_manager_test.cc | 138 +++++++++++++---
 2 files changed, 87 insertions(+), 54 deletions(-)

diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index 20ba3caf9a5..e740241ea40 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -209,6 +209,9 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
   std::sort(collective->participants.begin(), collective->participants.end(),
             [](const std::unique_ptr<Participant>& a,
                const std::unique_ptr<Participant>& b) {
+              if (a->executor == b->executor) {
+                return a->global_rank < b->global_rank;
+              }
               return a->executor < b->executor;
             });

diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc
index 161a88937c3..ece2f16d6b4 100644
--- a/tensorflow/core/nccl/nccl_manager_test.cc
+++ b/tensorflow/core/nccl/nccl_manager_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" namespace tensorflow { @@ -60,7 +61,8 @@ class NcclManagerTest : public ::testing::Test { mutex mu; Status final_status; - int num_completed = 0; + int num_completed GUARDED_BY(mu) = 0; + condition_variable done_cv; }; static void SetUpTestSuite() { @@ -68,13 +70,20 @@ class NcclManagerTest : public ::testing::Test { setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */); devices_ = new std::vector>(GetGPUDevices()); LOG(INFO) << "Running test with " << devices_->size() << " gpus"; + work_queue_ = new UnboundedWorkQueue(Env::Default(), "nccl_manager_test"); } - void SetUp() override { ASSERT_GT(devices_->size(), 0) << "No GPUs found"; } + void SetUp() override { + ASSERT_GT(devices_->size(), 0) << "No GPUs found"; + ASSERT_NE(work_queue_, nullptr); + } static int32 NumGPUs() { return static_cast(devices_->size()); } - static void TearDownTestSuite() { delete devices_; } + static void TearDownTestSuite() { + delete devices_; + delete work_queue_; + } TestCase* MakeReductionTestCase(int num_nodes, int num_ranks_per_node, ncclRedOp_t reduction_op, TensorShape shape, @@ -221,13 +230,10 @@ class NcclManagerTest : public ::testing::Test { // Waits for the done callback to be called for each participant. void WaitForTestCompletion(TestCase* test_case) { - test_case->mu.lock(); + mutex_lock l(test_case->mu); while (test_case->num_completed != test_case->outs.size()) { - test_case->mu.unlock(); - Env::Default()->SleepForMicroseconds(10); - test_case->mu.lock(); + test_case->done_cv.wait(l); } - test_case->mu.unlock(); } void VerifyResults(TestCase* test_case) { @@ -259,12 +265,15 @@ class NcclManagerTest : public ::testing::Test { NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) { return [this, test_case](Status s) { mutex_lock l(test_case->mu); - ++test_case->num_completed; test_case->final_status.Update(s); + if (++test_case->num_completed == test_case->outs.size()) { + test_case->done_cv.notify_one(); + } }; } - void RunMultiNodeTest(const int num_nodes, const int num_ranks_per_node) { + void RunMultiNodeAllReduceTest(const int num_nodes, + const int num_ranks_per_node) { const int num_global_ranks = num_nodes * num_ranks_per_node; std::vector nccl_managers(num_nodes); const string collective_key = "allreduce"; @@ -272,7 +281,6 @@ class NcclManagerTest : public ::testing::Test { // each node's code in a separate thread. // Specifically, the call to ncclGroupEnd() after calling ncclCommInitRank // waits for all communicators before returning. - thread::ThreadPool pool(Env::Default(), "test_multi_node_nccl", num_nodes); // First, initialize the communicator_key used for this collective. const string communicator_key = nccl_managers[0].GenerateCommunicatorKey(); @@ -308,7 +316,7 @@ class NcclManagerTest : public ::testing::Test { // Signal collective ready to launch at this node. 
         nccl_managers[node].SignalMultiNodeReady(collective_key);
       };
-      pool.Schedule(node_fn);
+      this->work_queue_->Schedule(node_fn);
     }

     VLOG(2) << "Verifying results";
@@ -316,10 +324,52 @@
     }
   }

+  void RunBroadcastTest(const int num_ranks, const int src_rank,
+                        const bool in_place) {
+    std::unique_ptr<TestCase> test_case(this->MakeBroadcastTestCase(
+        /*num_nodes=*/1, num_ranks, TensorShape({5, 6}), /*src_node=*/0,
+        src_rank, in_place));
+    auto done = this->CreateDoneCallback(test_case.get());
+    for (int rank = 0; rank < num_ranks; ++rank) {
+      // Launch each rank in a separate thread to test concurrent,
+      // randomly-ordered calls into NcclManager.
+      this->work_queue_->Schedule(
+          [this, num_ranks, src_rank, rank, &test_case, &done]() {
+            auto* device = this->GetDevice(rank);
+            auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+            auto* stream = device->tensorflow_gpu_device_info()->stream;
+            auto* input = rank == src_rank ? &test_case->ins[rank] : nullptr;
+            auto* output = test_case->outs[rank].NumElements() == 0
+                               ? nullptr
+                               : &test_case->outs[rank];
+            auto participant = absl::make_unique<NcclManager::Participant>(
+                device->executor(), stream, event_mgr, device->gpu_id(), input,
+                output, rank, done);
+            if (rank == src_rank) {
+              NcclManager::instance()->AddBroadcastSend(
+                  std::move(participant),
+                  {"broadcast", /*num_local_devices=*/num_ranks,
+                   /*num_global_devices=*/num_ranks,
+                   /*communicator_key=*/""});
+            } else {
+              NcclManager::instance()->AddBroadcastRecv(
+                  std::move(participant),
+                  {"broadcast", /*num_local_devices=*/num_ranks,
+                   /*num_global_devices=*/num_ranks,
+                   /*communicator_key=*/""});
+            }
+          });
+    }
+
+    this->VerifyResults(test_case.get());
+  }
+
   static BaseGPUDevice* GetDevice(size_t rank) {
     return devices_->at(rank % devices_->size()).get();
   }

+  static UnboundedWorkQueue* work_queue_;
+
  private:
   static Allocator* GpuAllocator(BaseGPUDevice* device) {
     return device->GetAllocator(AllocatorAttributes());
   }
@@ -331,7 +381,6 @@ class NcclManagerTest : public ::testing::Test {
     return typed;
   }

- private:
   static std::vector<std::unique_ptr<BaseGPUDevice>>* devices_;
   static const DataType data_type_;
   static const Scalar max_;
@@ -346,6 +395,8 @@ const DataType NcclManagerTest<Scalar>::data_type_ =
 template <typename Scalar>
 const Scalar NcclManagerTest<Scalar>::max_ =
     Eigen::NumTraits<Scalar>::highest();
+template <typename Scalar>
+UnboundedWorkQueue* NcclManagerTest<Scalar>::work_queue_ = nullptr;

 // Instantiate tests for float and double.
 using TypeList = ::testing::Types<float, double>;
@@ -389,7 +440,6 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) {
 TYPED_TEST(NcclManagerTest, MultipleCallers) {
   const int num_ranks = 4;
   const int num_collectives_per_iteration = 10;
-  const int num_threads = num_ranks * 2;
   const int time_limit_micros = 1 * 1000 * 1000;  // 1 second

   int64 start = Env::Default()->NowMicros();
@@ -417,8 +467,6 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) {
                  std::mt19937(std::random_device()()));

     mutex mu;  // guards case_and_rank.
-    std::unique_ptr<thread::ThreadPool> pool(
-        new thread::ThreadPool(Env::Default(), "test", num_threads));
     const int to_schedule = case_and_rank.size();
     for (int i = 0; i < to_schedule; ++i) {
       auto fn = [&]() {
@@ -446,9 +494,8 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) {
              /*communicator_key=*/""},
            ncclSum);
       };
-      pool->Schedule(fn);
+      this->work_queue_->Schedule(fn);
     }
-    pool.reset();  // wait for all work to be scheduled.

     VLOG(2) << "Verifying results for " << num_collectives_per_iteration
             << " collectives";
@@ -494,41 +541,24 @@ TYPED_TEST(NcclManagerTest, BasicAllGather) {

 // Test basic broadcast.
 TYPED_TEST(NcclManagerTest, BasicBroadcast) {
-  const int num_ranks = 4;
-  const int src_rank = 2;
-  for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) {
-    bool in_place = in_place_idx == 1;
-    std::unique_ptr<typename TestFixture::TestCase> test_case(
-        this->MakeBroadcastTestCase(/*num_nodes=*/1, num_ranks,
-                                    TensorShape({5, 6}), /*src_node=*/0,
-                                    src_rank, in_place));
-    for (int rank = 0; rank < num_ranks; ++rank) {
-      auto* device = this->GetDevice(rank);
-      auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
-      auto* stream = device->tensorflow_gpu_device_info()->stream;
-      auto* input = rank == src_rank ? &test_case->ins[rank] : nullptr;
-      auto* output = test_case->outs[rank].NumElements() == 0
-                         ? nullptr
-                         : &test_case->outs[rank];
-      auto participant = absl::make_unique<NcclManager::Participant>(
-          device->executor(), stream, event_mgr, device->gpu_id(), input,
-          output, rank, this->CreateDoneCallback(test_case.get()));
-      if (rank == src_rank) {
-        NcclManager::instance()->AddBroadcastSend(
-            std::move(participant),
-            {"broadcast", /*num_local_devices=*/num_ranks,
-             /*num_global_devices=*/num_ranks,
-             /*communicator_key=*/""});
-      } else {
-        NcclManager::instance()->AddBroadcastRecv(
-            std::move(participant),
-            {"broadcast", /*num_local_devices=*/num_ranks,
-             /*num_global_devices=*/num_ranks,
-             /*communicator_key=*/""});
-      }
-    }
+  this->RunBroadcastTest(/*num_ranks=*/4, /*src_rank=*/2,
+                         /*in_place=*/false);
+}

-    this->VerifyResults(test_case.get());
+// Test in-place broadcast.
+TYPED_TEST(NcclManagerTest, InPlaceBroadcast) {
+  this->RunBroadcastTest(/*num_ranks=*/4, /*src_rank=*/1,
+                         /*in_place=*/true);
+}
+
+// Test broadcast with increasing ranks.
+TYPED_TEST(NcclManagerTest, BroadcastWithDifferentRanks) {
+  for (int num_ranks = 4; num_ranks <= 8; ++num_ranks) {
+    const int src_rank = static_cast<int>(random::New64() % num_ranks);
+    for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) {
+      const bool in_place = in_place_idx == 0;
+      this->RunBroadcastTest(num_ranks, src_rank, in_place);
+    }
   }
 }

@@ -544,13 +574,13 @@ TEST(NcclManagerTest, CommunicatorKey) {
 // environment. It works on a single node and reuses GPUs. It enqueues NCCL
 // kernels on separate stream per rank.
 TYPED_TEST(NcclManagerTest, MultiNode) {
-  this->RunMultiNodeTest(/*num_nodes=*/2, /*num_ranks_per_node=*/4);
+  this->RunMultiNodeAllReduceTest(/*num_nodes=*/2, /*num_ranks_per_node=*/4);
 }

 // Tests that specifying `communicator_key` with a single node NCCL collective
 // works well.
 TYPED_TEST(NcclManagerTest, MultiNodeSingle) {
-  this->RunMultiNodeTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4);
+  this->RunMultiNodeAllReduceTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4);
 }

 // Checks that we return error status if a collective_key is used for different
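
A minimal, standalone sketch of the ordering the comparator in nccl_manager.cc establishes: participants are sorted primarily by executor and, within the same executor, by global rank, so the i-th participant always lines up with the i-th communicator member. This sketch is not part of the patch; the `Participant` struct and the integer `executor_id` below are simplified stand-ins for the real `NcclManager::Participant` and its `se::StreamExecutor*` pointer.

#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-in for NcclManager::Participant; executor_id replaces the
// se::StreamExecutor* pointer used by the real code.
struct Participant {
  int executor_id;
  int global_rank;
};

int main() {
  std::vector<std::unique_ptr<Participant>> participants;
  // Two participants share executor 0 (same GPU); insertion order is arbitrary.
  participants.push_back(std::make_unique<Participant>(Participant{1, 3}));
  participants.push_back(std::make_unique<Participant>(Participant{0, 2}));
  participants.push_back(std::make_unique<Participant>(Participant{0, 0}));

  // Same shape as the fix: order by executor first, then by rank within an
  // executor, so the ordering is deterministic.
  std::sort(participants.begin(), participants.end(),
            [](const std::unique_ptr<Participant>& a,
               const std::unique_ptr<Participant>& b) {
              if (a->executor_id == b->executor_id) {
                return a->global_rank < b->global_rank;
              }
              return a->executor_id < b->executor_id;
            });

  // Prints "(0,0) (0,2) (1,3)": within executor 0, rank 0 now precedes rank 2.
  for (const auto& p : participants) {
    std::cout << "(" << p->executor_id << "," << p->global_rank << ") ";
  }
  std::cout << "\n";
  return 0;
}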