diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index 35157bad58f..487976bb012 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -52,11 +52,13 @@ tf_cuda_cc_test( size = "medium", srcs = ["nccl_manager_test.cc"], tags = tf_cuda_tests_tags() + [ - "no_cuda_on_cpu_tap", - # TODO(b/120284216): Add 'multi_gpu' tag and replace 'no_rocm' with 'rocm_multi_gpu'. - # The test fails on CUDA multi_gpu, and that tag also triggers on rocm_multi_gpu. - # The test also fails on ROCm unless 4 GPUs are used. + "guitar", + "manual", + "multi_gpu", + "no_oss", + # TODO(b/147451637): Replace 'no_rocm' with 'rocm_multi_gpu'. "no_rocm", + "notap", ], deps = [ "//tensorflow/core:test", diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc index 8d4e48c9e33..fcbae5622d6 100644 --- a/tensorflow/core/nccl/nccl_manager_test.cc +++ b/tensorflow/core/nccl/nccl_manager_test.cc @@ -32,13 +32,8 @@ namespace tensorflow { static std::vector> GetGPUDevices() { std::vector> devices; - SessionOptions session_options; - session_options.config.mutable_gpu_options() - ->set_per_process_gpu_memory_fraction(0.1); - session_options.env = Env::Default(); - Status s = DeviceFactory::GetFactory(DEVICE_GPU) - ->AddDevices(session_options, "", &devices); - TF_CHECK_OK(s); + TF_CHECK_OK(DeviceFactory::GetFactory(DEVICE_GPU) + ->AddDevices(SessionOptions(), "", &devices)); std::vector> gpus; for (std::unique_ptr& device : devices) { if (device->device_type() == "GPU") { @@ -55,9 +50,13 @@ class NcclManagerTest : public ::testing::Test { public: // A single all-reduce to apply. struct TestCase { + TestCase(int num_nodes, int num_ranks_per_node) + : num_nodes(num_nodes), num_ranks_per_node(num_ranks_per_node) {} std::vector ins; std::vector outs; Tensor expected; + const int num_nodes; + const int num_ranks_per_node; mutex mu; Status final_status; @@ -69,7 +68,10 @@ class NcclManagerTest : public ::testing::Test { setenv("NCCL_DEBUG", "INFO", 1 /* replace */); setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */); devices_ = new std::vector>(GetGPUDevices()); - LOG(INFO) << "Running test with " << devices_->size() << " gpus"; + VLOG(1) << "Running test with " << devices_->size() << " gpus"; + if (devices_->size() <= 1) { + LOG(FATAL) << "Cannot run NCCL test without multiple GPUs"; + } work_queue_ = new UnboundedWorkQueue(Env::Default(), "nccl_manager_test"); } @@ -80,6 +82,19 @@ class NcclManagerTest : public ::testing::Test { static int32 NumGPUs() { return static_cast(devices_->size()); } + // Let N = #GPUs. When N is even, num_nodes=2 and num_ranks_per_node=N/2. + // When N is odd, num_nodes=2 and num_ranks_per_node=(N-1)/2. + static void PopulateMultiNodeParams(int* num_nodes, int* num_ranks_per_node) { + const auto num_gpus = NumGPUs(); + CHECK_GT(num_gpus, 1); + *num_nodes = 2; + if (num_gpus % 2 == 0) { + *num_ranks_per_node = num_gpus / 2; + } else { + *num_ranks_per_node = (num_gpus - 1) / 2; + } + } + static void TearDownTestSuite() { delete devices_; delete work_queue_; @@ -88,7 +103,7 @@ class NcclManagerTest : public ::testing::Test { TestCase* MakeReductionTestCase(int num_nodes, int num_ranks_per_node, ncclRedOp_t reduction_op, TensorShape shape, float value_offset) { - TestCase* test_case = new TestCase(); + TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node); test_case->expected = Tensor(data_type_, shape); if (reduction_op == ncclProd) { test::FillFn(&test_case->expected, @@ -107,7 +122,7 @@ class NcclManagerTest : public ::testing::Test { float value_scale = 0.01; // Small scale to avoid fp16 overflow. for (int node = 0; node < num_nodes; ++node) { for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) { - auto* device = GetDevice(local_rank); + auto* device = GetDevice(num_ranks_per_node, node, local_rank); auto* stream = device->tensorflow_gpu_device_info()->stream; Tensor in_cpu(data_type_, shape); @@ -148,7 +163,7 @@ class NcclManagerTest : public ::testing::Test { TestCase* MakeGatherTestCase(int num_nodes, int num_ranks_per_node, TensorShape in_shape, TensorShape out_shape) { - TestCase* test_case = new TestCase(); + TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node); test_case->expected = Tensor(data_type_, out_shape); test::FillFn(&test_case->expected, [](int) { return static_cast(0); }); @@ -156,7 +171,7 @@ class NcclManagerTest : public ::testing::Test { float value_scale = 0.01; // Small scale to avoid fp16 overflow. for (int node = 0; node < num_nodes; ++node) { for (int i = 0; i < num_ranks_per_node; ++i) { - auto* device = GetDevice(i); + auto* device = GetDevice(num_ranks_per_node, node, i); auto* stream = device->tensorflow_gpu_device_info()->stream; Tensor in_cpu(data_type_, in_shape); @@ -194,14 +209,14 @@ class NcclManagerTest : public ::testing::Test { TestCase* MakeBroadcastTestCase(int num_nodes, int num_ranks_per_node, TensorShape shape, int src_node, int src_rank, bool in_place) { - TestCase* test_case = new TestCase(); + TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node); test_case->expected = Tensor(data_type_, shape); test::FillFn(&test_case->expected, [](int) { return static_cast(1); }); for (int node = 0; node < num_nodes; ++node) { for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) { - auto* device = GetDevice(local_rank); + auto* device = GetDevice(num_ranks_per_node, node, local_rank); if (node == src_node && local_rank == src_rank) { test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape); if (in_place) { @@ -240,19 +255,25 @@ class NcclManagerTest : public ::testing::Test { WaitForTestCompletion(test_case); TF_ASSERT_OK(test_case->final_status); // Copy memory to host and verify. - for (int rank = 0; rank < test_case->outs.size(); ++rank) { - auto* device = GetDevice(rank); - auto* stream = device->tensorflow_gpu_device_info()->stream; - const Tensor& out_gpu = test_case->outs[rank]; - Tensor out_cpu(data_type_, out_gpu.shape()); - auto out_gpu_mem = AsDeviceMemory(out_gpu.flat().data()); - stream->ThenMemcpy(out_cpu.flat().data(), out_gpu_mem, - out_cpu.TotalBytes()); - SE_ASSERT_OK(stream->BlockHostUntilDone()); - VLOG(1) << "Verifying rank " << rank << " expected shape " - << test_case->expected.shape() << " out shape " - << out_cpu.shape(); - test::ExpectClose(test_case->expected, out_cpu); + for (int node = 0; node < test_case->num_nodes; ++node) { + for (int local_rank = 0; local_rank < test_case->num_ranks_per_node; + ++local_rank) { + auto* device = + GetDevice(test_case->num_ranks_per_node, node, local_rank); + auto* stream = device->tensorflow_gpu_device_info()->stream; + const int global_rank = + GlobalRank(test_case->num_ranks_per_node, node, local_rank); + const Tensor& out_gpu = test_case->outs[global_rank]; + Tensor out_cpu(data_type_, out_gpu.shape()); + auto out_gpu_mem = AsDeviceMemory(out_gpu.flat().data()); + stream->ThenMemcpy(out_cpu.flat().data(), out_gpu_mem, + out_cpu.TotalBytes()); + SE_ASSERT_OK(stream->BlockHostUntilDone()); + VLOG(1) << "Verifying rank " << global_rank << " expected shape " + << test_case->expected.shape() << " out shape " + << out_cpu.shape(); + test::ExpectClose(test_case->expected, out_cpu); + } } } @@ -302,10 +323,11 @@ class NcclManagerTest : public ::testing::Test { reduction_op, &test_case] { for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) { - auto* device = this->GetDevice(local_rank); + auto* device = GetDevice(num_ranks_per_node, node, local_rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; - const int global_rank = node * num_ranks_per_node + local_rank; + const int global_rank = + GlobalRank(num_ranks_per_node, node, local_rank); auto participant = absl::make_unique( device->executor(), stream, info, &test_case->ins[global_rank], &test_case->outs[global_rank], global_rank, @@ -350,10 +372,11 @@ class NcclManagerTest : public ::testing::Test { auto rank_fn = [this, node, num_ranks_per_node, num_global_ranks, src_global_rank, local_rank, &node_states, &collective_key, &communicator_key, &test_case]() { - auto* device = this->GetDevice(local_rank); + auto* device = GetDevice(num_ranks_per_node, node, local_rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; - const int global_rank = node * num_ranks_per_node + local_rank; + const int global_rank = + GlobalRank(num_ranks_per_node, node, local_rank); auto* input = global_rank == src_global_rank ? &test_case->ins[global_rank] : nullptr; @@ -388,8 +411,15 @@ class NcclManagerTest : public ::testing::Test { this->VerifyResults(test_case.get()); } - static BaseGPUDevice* GetDevice(size_t rank) { - return devices_->at(rank % devices_->size()).get(); + static int GlobalRank(int num_ranks_per_node, int node, int local_rank) { + return node * num_ranks_per_node + local_rank; + } + + static BaseGPUDevice* GetDevice(int num_ranks_per_node, int node, + int local_rank) { + const int device_idx = GlobalRank(num_ranks_per_node, node, local_rank); + CHECK_LT(device_idx, devices_->size()); + return (*devices_)[device_idx].get(); } static UnboundedWorkQueue* work_queue_; @@ -428,7 +458,7 @@ TYPED_TEST_SUITE(NcclManagerTest, TypeList); // Test basic sum reduction. TYPED_TEST(NcclManagerTest, BasicSumReduction) { - const int num_ranks = 4; + const int num_ranks = this->NumGPUs(); for (int op = 0; op < 4; ++op) { ncclRedOp_t reduction_op = static_cast(op); @@ -436,7 +466,7 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) { this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, reduction_op, TensorShape({2, 3}), 0.0f)); for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); VLOG(2) << "rank " << rank << " device " << device->name(); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; @@ -463,7 +493,7 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) { // To run test longer, increase num_ranks, num_collectives_per_iteration and // time_limit_micros. TYPED_TEST(NcclManagerTest, MultipleCallers) { - const int num_ranks = 4; + const int num_ranks = this->NumGPUs(); const int num_collectives_per_iteration = 10; const int time_limit_micros = 1 * 1000 * 1000; // 1 second @@ -483,7 +513,7 @@ TYPED_TEST(NcclManagerTest, MultipleCallers) { } for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); auto* stream = device->tensorflow_gpu_device_info()->stream; SE_ASSERT_OK(stream->BlockHostUntilDone()); } @@ -503,7 +533,7 @@ TYPED_TEST(NcclManagerTest, MultipleCallers) { rank = case_and_rank.back().second; case_and_rank.pop_back(); } - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; typename TestFixture::TestCase* test_case = test_cases[test_num].get(); @@ -538,14 +568,14 @@ TYPED_TEST(NcclManagerTest, MultipleCallers) { // Test basic all-gather. TYPED_TEST(NcclManagerTest, BasicAllGather) { - const int num_ranks = 4; + const int num_ranks = this->NumGPUs(); for (int i = 0; i < num_ranks; ++i) { std::unique_ptr test_case( this->MakeGatherTestCase(/*num_nodes=*/1, num_ranks, TensorShape({2, 3}), TensorShape({2 * num_ranks, 3}))); for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); VLOG(2) << "rank " << rank << " device " << device->name(); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; @@ -567,26 +597,23 @@ TYPED_TEST(NcclManagerTest, BasicAllGather) { // Test basic broadcast. TYPED_TEST(NcclManagerTest, BasicBroadcast) { - this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4, - /*src_node=*/0, /*src_local_rank=*/2, + this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, + /*num_ranks_per_node=*/this->NumGPUs(), + /*src_node=*/0, /*src_local_rank=*/0, /*in_place=*/false); } // Test in-place broadcast. TYPED_TEST(NcclManagerTest, InPlaceBroadcast) { - this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4, - /*src_node=*/0, /*src_local_rank=*/1, + this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, + /*num_ranks_per_node=*/this->NumGPUs(), + /*src_node=*/0, /*src_local_rank=*/0, /*in_place=*/true); } // Test broadcast with increasing ranks. TYPED_TEST(NcclManagerTest, BroadcastWithDifferentRanks) { -#if TENSORFLOW_USE_ROCM - for (int num_ranks = 1; num_ranks <= 4; ++num_ranks) -#else - for (int num_ranks = 4; num_ranks <= 8; ++num_ranks) -#endif - { + for (int num_ranks = 1; num_ranks <= this->NumGPUs(); ++num_ranks) { const int src_rank = static_cast(random::New64() % num_ranks); for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) { const bool in_place = in_place_idx == 0; @@ -606,42 +633,49 @@ TEST(NcclManagerTest, CommunicatorKey) { #if !TENSORFLOW_USE_ROCM // This test creates `num_nodes` NcclManagers to simulate a multi-node -// environment. It works on a single node and reuses GPUs. It enqueues NCCL +// environment. It works on a single node with multiple GPUs. It enqueues NCCL // kernels on separate stream per rank. TYPED_TEST(NcclManagerTest, MultiNode) { - this->RunMultiNodeAllReduceTest(/*num_nodes=*/2, /*num_ranks_per_node=*/4); + int num_nodes; + int num_ranks_per_node; + this->PopulateMultiNodeParams(&num_nodes, &num_ranks_per_node); + VLOG(1) << "Calling RunMultiNodeAllReduceTest with num_nodes=" << num_nodes + << " and num_ranks_per_node=" << num_ranks_per_node; + this->RunMultiNodeAllReduceTest(num_nodes, num_ranks_per_node); } #endif // Tests that specifying `communicator_key` with a single node NCCL collective // works well. TYPED_TEST(NcclManagerTest, MultiNodeSingle) { - this->RunMultiNodeAllReduceTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4); + this->RunMultiNodeAllReduceTest(/*num_nodes=*/1, + /*num_ranks_per_node=*/this->NumGPUs()); } +#if !TENSORFLOW_USE_ROCM // Multi-node broadcast. TYPED_TEST(NcclManagerTest, MultiNodeBroadcast) { -#if TENSORFLOW_USE_ROCM - this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4, - /*src_node=*/0, /*src_local_rank=*/3, - /*in_place=*/true); -#else - this->RunMultiNodeBroadcastTest(/*num_nodes=*/4, /*num_ranks_per_node=*/8, - /*src_node=*/2, /*src_local_rank=*/3, + int num_nodes; + int num_ranks_per_node; + this->PopulateMultiNodeParams(&num_nodes, &num_ranks_per_node); + VLOG(1) << "Calling RunMultiNodeBroadcastTest with num_nodes=" << num_nodes + << " and num_ranks_per_node=" << num_ranks_per_node; + this->RunMultiNodeBroadcastTest(num_nodes, num_ranks_per_node, + /*src_node=*/0, /*src_local_rank=*/0, /*in_place=*/true); #endif } // Checks that we return error status if a collective_key is used for different -// types of collectives, e.g. a reduction and a broadcast. +// types of collectives, e.g.a reduction and a broadcast. TYPED_TEST(NcclManagerTest, ConsistentCollectiveType) { const int num_ranks = 2; std::unique_ptr test_case( - this->MakeReductionTestCase(1 /* num_nodes */, num_ranks, ncclSum, + this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum, TensorShape({2, 3}), 0.0f)); for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; auto participant = absl::make_unique( @@ -675,10 +709,10 @@ TYPED_TEST(NcclManagerTest, ConsistentCommunicatorKey) { const int num_ranks = 2; std::unique_ptr test_case( - this->MakeReductionTestCase(1 /* num_nodes */, num_ranks, ncclSum, + this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum, TensorShape({2, 3}), 0.0f)); for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; auto participant = absl::make_unique( @@ -704,10 +738,10 @@ TYPED_TEST(NcclManagerTest, ConsistentNumberOfDevices) { const int num_ranks = 2; std::unique_ptr test_case( - this->MakeReductionTestCase(1 /* num_nodes */, num_ranks, ncclSum, + this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum, TensorShape({2, 3}), 0.0f)); for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; int num_devices = rank == 0 ? num_ranks : num_ranks + 1; @@ -736,7 +770,7 @@ TYPED_TEST(NcclManagerTest, BroadcastNoSource) { TensorShape({2, 3}), /*src_node=*/-1, /*src_rank=*/-1, false)); for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; auto participant = absl::make_unique( @@ -762,7 +796,7 @@ TYPED_TEST(NcclManagerTest, BroadcastMultipleSends) { TensorShape({2, 3}), /*src_node=*/-1, /*src_rank=*/-1, false)); for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; auto participant = absl::make_unique( @@ -790,7 +824,7 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) { TensorShape({2, 3}), /*src_node=*/-1, /*src_rank=*/-1, false)); for (int rank = 0; rank < num_ranks; ++rank) { - auto* device = this->GetDevice(rank); + auto* device = this->GetDevice(num_ranks, /*node=*/0, rank); auto* info = device->tensorflow_gpu_device_info(); auto* stream = device->tensorflow_gpu_device_info()->stream; auto participant = absl::make_unique(