Disable nccl_manager_test on single GPU and re-enable with multiple GPUs.

This change modifies `nccl_manager_test` so that it runs with multiple physical
GPUs.  The main changes are to pick the number of nodes and ranks based on the
actual devices available.

PiperOrigin-RevId: 289146110
Change-Id: I5d06ac39eee3ffe69311194485fc64974bc5410f
This commit is contained in:
Ayush Dubey 2020-01-10 12:46:04 -08:00 committed by TensorFlower Gardener
parent bdb99e06c5
commit d65a5f1bdf
2 changed files with 110 additions and 74 deletions

View File

@ -52,11 +52,13 @@ tf_cuda_cc_test(
size = "medium", size = "medium",
srcs = ["nccl_manager_test.cc"], srcs = ["nccl_manager_test.cc"],
tags = tf_cuda_tests_tags() + [ tags = tf_cuda_tests_tags() + [
"no_cuda_on_cpu_tap", "guitar",
# TODO(b/120284216): Add 'multi_gpu' tag and replace 'no_rocm' with 'rocm_multi_gpu'. "manual",
# The test fails on CUDA multi_gpu, and that tag also triggers on rocm_multi_gpu. "multi_gpu",
# The test also fails on ROCm unless 4 GPUs are used. "no_oss",
# TODO(b/147451637): Replace 'no_rocm' with 'rocm_multi_gpu'.
"no_rocm", "no_rocm",
"notap",
], ],
deps = [ deps = [
"//tensorflow/core:test", "//tensorflow/core:test",

View File

@ -32,13 +32,8 @@ namespace tensorflow {
static std::vector<std::unique_ptr<BaseGPUDevice>> GetGPUDevices() { static std::vector<std::unique_ptr<BaseGPUDevice>> GetGPUDevices() {
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
SessionOptions session_options; TF_CHECK_OK(DeviceFactory::GetFactory(DEVICE_GPU)
session_options.config.mutable_gpu_options() ->AddDevices(SessionOptions(), "", &devices));
->set_per_process_gpu_memory_fraction(0.1);
session_options.env = Env::Default();
Status s = DeviceFactory::GetFactory(DEVICE_GPU)
->AddDevices(session_options, "", &devices);
TF_CHECK_OK(s);
std::vector<std::unique_ptr<BaseGPUDevice>> gpus; std::vector<std::unique_ptr<BaseGPUDevice>> gpus;
for (std::unique_ptr<Device>& device : devices) { for (std::unique_ptr<Device>& device : devices) {
if (device->device_type() == "GPU") { if (device->device_type() == "GPU") {
@ -55,9 +50,13 @@ class NcclManagerTest : public ::testing::Test {
public: public:
// A single all-reduce to apply. // A single all-reduce to apply.
struct TestCase { struct TestCase {
TestCase(int num_nodes, int num_ranks_per_node)
: num_nodes(num_nodes), num_ranks_per_node(num_ranks_per_node) {}
std::vector<Tensor> ins; std::vector<Tensor> ins;
std::vector<Tensor> outs; std::vector<Tensor> outs;
Tensor expected; Tensor expected;
const int num_nodes;
const int num_ranks_per_node;
mutex mu; mutex mu;
Status final_status; Status final_status;
@ -69,7 +68,10 @@ class NcclManagerTest : public ::testing::Test {
setenv("NCCL_DEBUG", "INFO", 1 /* replace */); setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */); setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
devices_ = new std::vector<std::unique_ptr<BaseGPUDevice>>(GetGPUDevices()); devices_ = new std::vector<std::unique_ptr<BaseGPUDevice>>(GetGPUDevices());
LOG(INFO) << "Running test with " << devices_->size() << " gpus"; VLOG(1) << "Running test with " << devices_->size() << " gpus";
if (devices_->size() <= 1) {
LOG(FATAL) << "Cannot run NCCL test without multiple GPUs";
}
work_queue_ = new UnboundedWorkQueue(Env::Default(), "nccl_manager_test"); work_queue_ = new UnboundedWorkQueue(Env::Default(), "nccl_manager_test");
} }
@ -80,6 +82,19 @@ class NcclManagerTest : public ::testing::Test {
static int32 NumGPUs() { return static_cast<int32>(devices_->size()); } static int32 NumGPUs() { return static_cast<int32>(devices_->size()); }
// Let N = #GPUs. When N is even, num_nodes=2 and num_ranks_per_node=N/2.
// When N is odd, num_nodes=2 and num_ranks_per_node=(N-1)/2.
static void PopulateMultiNodeParams(int* num_nodes, int* num_ranks_per_node) {
const auto num_gpus = NumGPUs();
CHECK_GT(num_gpus, 1);
*num_nodes = 2;
if (num_gpus % 2 == 0) {
*num_ranks_per_node = num_gpus / 2;
} else {
*num_ranks_per_node = (num_gpus - 1) / 2;
}
}
static void TearDownTestSuite() { static void TearDownTestSuite() {
delete devices_; delete devices_;
delete work_queue_; delete work_queue_;
@ -88,7 +103,7 @@ class NcclManagerTest : public ::testing::Test {
TestCase* MakeReductionTestCase(int num_nodes, int num_ranks_per_node, TestCase* MakeReductionTestCase(int num_nodes, int num_ranks_per_node,
ncclRedOp_t reduction_op, TensorShape shape, ncclRedOp_t reduction_op, TensorShape shape,
float value_offset) { float value_offset) {
TestCase* test_case = new TestCase(); TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node);
test_case->expected = Tensor(data_type_, shape); test_case->expected = Tensor(data_type_, shape);
if (reduction_op == ncclProd) { if (reduction_op == ncclProd) {
test::FillFn<Scalar>(&test_case->expected, test::FillFn<Scalar>(&test_case->expected,
@ -107,7 +122,7 @@ class NcclManagerTest : public ::testing::Test {
float value_scale = 0.01; // Small scale to avoid fp16 overflow. float value_scale = 0.01; // Small scale to avoid fp16 overflow.
for (int node = 0; node < num_nodes; ++node) { for (int node = 0; node < num_nodes; ++node) {
for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) { for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) {
auto* device = GetDevice(local_rank); auto* device = GetDevice(num_ranks_per_node, node, local_rank);
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
Tensor in_cpu(data_type_, shape); Tensor in_cpu(data_type_, shape);
@ -148,7 +163,7 @@ class NcclManagerTest : public ::testing::Test {
TestCase* MakeGatherTestCase(int num_nodes, int num_ranks_per_node, TestCase* MakeGatherTestCase(int num_nodes, int num_ranks_per_node,
TensorShape in_shape, TensorShape out_shape) { TensorShape in_shape, TensorShape out_shape) {
TestCase* test_case = new TestCase(); TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node);
test_case->expected = Tensor(data_type_, out_shape); test_case->expected = Tensor(data_type_, out_shape);
test::FillFn<Scalar>(&test_case->expected, test::FillFn<Scalar>(&test_case->expected,
[](int) { return static_cast<Scalar>(0); }); [](int) { return static_cast<Scalar>(0); });
@ -156,7 +171,7 @@ class NcclManagerTest : public ::testing::Test {
float value_scale = 0.01; // Small scale to avoid fp16 overflow. float value_scale = 0.01; // Small scale to avoid fp16 overflow.
for (int node = 0; node < num_nodes; ++node) { for (int node = 0; node < num_nodes; ++node) {
for (int i = 0; i < num_ranks_per_node; ++i) { for (int i = 0; i < num_ranks_per_node; ++i) {
auto* device = GetDevice(i); auto* device = GetDevice(num_ranks_per_node, node, i);
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
Tensor in_cpu(data_type_, in_shape); Tensor in_cpu(data_type_, in_shape);
@ -194,14 +209,14 @@ class NcclManagerTest : public ::testing::Test {
TestCase* MakeBroadcastTestCase(int num_nodes, int num_ranks_per_node, TestCase* MakeBroadcastTestCase(int num_nodes, int num_ranks_per_node,
TensorShape shape, int src_node, int src_rank, TensorShape shape, int src_node, int src_rank,
bool in_place) { bool in_place) {
TestCase* test_case = new TestCase(); TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node);
test_case->expected = Tensor(data_type_, shape); test_case->expected = Tensor(data_type_, shape);
test::FillFn<Scalar>(&test_case->expected, test::FillFn<Scalar>(&test_case->expected,
[](int) { return static_cast<Scalar>(1); }); [](int) { return static_cast<Scalar>(1); });
for (int node = 0; node < num_nodes; ++node) { for (int node = 0; node < num_nodes; ++node) {
for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) { for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) {
auto* device = GetDevice(local_rank); auto* device = GetDevice(num_ranks_per_node, node, local_rank);
if (node == src_node && local_rank == src_rank) { if (node == src_node && local_rank == src_rank) {
test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape); test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape);
if (in_place) { if (in_place) {
@ -240,19 +255,25 @@ class NcclManagerTest : public ::testing::Test {
WaitForTestCompletion(test_case); WaitForTestCompletion(test_case);
TF_ASSERT_OK(test_case->final_status); TF_ASSERT_OK(test_case->final_status);
// Copy memory to host and verify. // Copy memory to host and verify.
for (int rank = 0; rank < test_case->outs.size(); ++rank) { for (int node = 0; node < test_case->num_nodes; ++node) {
auto* device = GetDevice(rank); for (int local_rank = 0; local_rank < test_case->num_ranks_per_node;
auto* stream = device->tensorflow_gpu_device_info()->stream; ++local_rank) {
const Tensor& out_gpu = test_case->outs[rank]; auto* device =
Tensor out_cpu(data_type_, out_gpu.shape()); GetDevice(test_case->num_ranks_per_node, node, local_rank);
auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<Scalar>().data()); auto* stream = device->tensorflow_gpu_device_info()->stream;
stream->ThenMemcpy(out_cpu.flat<Scalar>().data(), out_gpu_mem, const int global_rank =
out_cpu.TotalBytes()); GlobalRank(test_case->num_ranks_per_node, node, local_rank);
SE_ASSERT_OK(stream->BlockHostUntilDone()); const Tensor& out_gpu = test_case->outs[global_rank];
VLOG(1) << "Verifying rank " << rank << " expected shape " Tensor out_cpu(data_type_, out_gpu.shape());
<< test_case->expected.shape() << " out shape " auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<Scalar>().data());
<< out_cpu.shape(); stream->ThenMemcpy(out_cpu.flat<Scalar>().data(), out_gpu_mem,
test::ExpectClose(test_case->expected, out_cpu); out_cpu.TotalBytes());
SE_ASSERT_OK(stream->BlockHostUntilDone());
VLOG(1) << "Verifying rank " << global_rank << " expected shape "
<< test_case->expected.shape() << " out shape "
<< out_cpu.shape();
test::ExpectClose(test_case->expected, out_cpu);
}
} }
} }
@ -302,10 +323,11 @@ class NcclManagerTest : public ::testing::Test {
reduction_op, &test_case] { reduction_op, &test_case] {
for (int local_rank = 0; local_rank < num_ranks_per_node; for (int local_rank = 0; local_rank < num_ranks_per_node;
++local_rank) { ++local_rank) {
auto* device = this->GetDevice(local_rank); auto* device = GetDevice(num_ranks_per_node, node, local_rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
const int global_rank = node * num_ranks_per_node + local_rank; const int global_rank =
GlobalRank(num_ranks_per_node, node, local_rank);
auto participant = absl::make_unique<NcclManager::Participant>( auto participant = absl::make_unique<NcclManager::Participant>(
device->executor(), stream, info, &test_case->ins[global_rank], device->executor(), stream, info, &test_case->ins[global_rank],
&test_case->outs[global_rank], global_rank, &test_case->outs[global_rank], global_rank,
@ -350,10 +372,11 @@ class NcclManagerTest : public ::testing::Test {
auto rank_fn = [this, node, num_ranks_per_node, num_global_ranks, auto rank_fn = [this, node, num_ranks_per_node, num_global_ranks,
src_global_rank, local_rank, &node_states, src_global_rank, local_rank, &node_states,
&collective_key, &communicator_key, &test_case]() { &collective_key, &communicator_key, &test_case]() {
auto* device = this->GetDevice(local_rank); auto* device = GetDevice(num_ranks_per_node, node, local_rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
const int global_rank = node * num_ranks_per_node + local_rank; const int global_rank =
GlobalRank(num_ranks_per_node, node, local_rank);
auto* input = global_rank == src_global_rank auto* input = global_rank == src_global_rank
? &test_case->ins[global_rank] ? &test_case->ins[global_rank]
: nullptr; : nullptr;
@ -388,8 +411,15 @@ class NcclManagerTest : public ::testing::Test {
this->VerifyResults(test_case.get()); this->VerifyResults(test_case.get());
} }
static BaseGPUDevice* GetDevice(size_t rank) { static int GlobalRank(int num_ranks_per_node, int node, int local_rank) {
return devices_->at(rank % devices_->size()).get(); return node * num_ranks_per_node + local_rank;
}
static BaseGPUDevice* GetDevice(int num_ranks_per_node, int node,
int local_rank) {
const int device_idx = GlobalRank(num_ranks_per_node, node, local_rank);
CHECK_LT(device_idx, devices_->size());
return (*devices_)[device_idx].get();
} }
static UnboundedWorkQueue* work_queue_; static UnboundedWorkQueue* work_queue_;
@ -428,7 +458,7 @@ TYPED_TEST_SUITE(NcclManagerTest, TypeList);
// Test basic sum reduction. // Test basic sum reduction.
TYPED_TEST(NcclManagerTest, BasicSumReduction) { TYPED_TEST(NcclManagerTest, BasicSumReduction) {
const int num_ranks = 4; const int num_ranks = this->NumGPUs();
for (int op = 0; op < 4; ++op) { for (int op = 0; op < 4; ++op) {
ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op); ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
@ -436,7 +466,7 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) {
this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, reduction_op, this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, reduction_op,
TensorShape({2, 3}), 0.0f)); TensorShape({2, 3}), 0.0f));
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
VLOG(2) << "rank " << rank << " device " << device->name(); VLOG(2) << "rank " << rank << " device " << device->name();
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
@ -463,7 +493,7 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) {
// To run test longer, increase num_ranks, num_collectives_per_iteration and // To run test longer, increase num_ranks, num_collectives_per_iteration and
// time_limit_micros. // time_limit_micros.
TYPED_TEST(NcclManagerTest, MultipleCallers) { TYPED_TEST(NcclManagerTest, MultipleCallers) {
const int num_ranks = 4; const int num_ranks = this->NumGPUs();
const int num_collectives_per_iteration = 10; const int num_collectives_per_iteration = 10;
const int time_limit_micros = 1 * 1000 * 1000; // 1 second const int time_limit_micros = 1 * 1000 * 1000; // 1 second
@ -483,7 +513,7 @@ TYPED_TEST(NcclManagerTest, MultipleCallers) {
} }
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
SE_ASSERT_OK(stream->BlockHostUntilDone()); SE_ASSERT_OK(stream->BlockHostUntilDone());
} }
@ -503,7 +533,7 @@ TYPED_TEST(NcclManagerTest, MultipleCallers) {
rank = case_and_rank.back().second; rank = case_and_rank.back().second;
case_and_rank.pop_back(); case_and_rank.pop_back();
} }
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
typename TestFixture::TestCase* test_case = test_cases[test_num].get(); typename TestFixture::TestCase* test_case = test_cases[test_num].get();
@ -538,14 +568,14 @@ TYPED_TEST(NcclManagerTest, MultipleCallers) {
// Test basic all-gather. // Test basic all-gather.
TYPED_TEST(NcclManagerTest, BasicAllGather) { TYPED_TEST(NcclManagerTest, BasicAllGather) {
const int num_ranks = 4; const int num_ranks = this->NumGPUs();
for (int i = 0; i < num_ranks; ++i) { for (int i = 0; i < num_ranks; ++i) {
std::unique_ptr<typename TestFixture::TestCase> test_case( std::unique_ptr<typename TestFixture::TestCase> test_case(
this->MakeGatherTestCase(/*num_nodes=*/1, num_ranks, this->MakeGatherTestCase(/*num_nodes=*/1, num_ranks,
TensorShape({2, 3}), TensorShape({2, 3}),
TensorShape({2 * num_ranks, 3}))); TensorShape({2 * num_ranks, 3})));
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
VLOG(2) << "rank " << rank << " device " << device->name(); VLOG(2) << "rank " << rank << " device " << device->name();
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
@ -567,26 +597,23 @@ TYPED_TEST(NcclManagerTest, BasicAllGather) {
// Test basic broadcast. // Test basic broadcast.
TYPED_TEST(NcclManagerTest, BasicBroadcast) { TYPED_TEST(NcclManagerTest, BasicBroadcast) {
this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4, this->RunMultiNodeBroadcastTest(/*num_nodes=*/1,
/*src_node=*/0, /*src_local_rank=*/2, /*num_ranks_per_node=*/this->NumGPUs(),
/*src_node=*/0, /*src_local_rank=*/0,
/*in_place=*/false); /*in_place=*/false);
} }
// Test in-place broadcast. // Test in-place broadcast.
TYPED_TEST(NcclManagerTest, InPlaceBroadcast) { TYPED_TEST(NcclManagerTest, InPlaceBroadcast) {
this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4, this->RunMultiNodeBroadcastTest(/*num_nodes=*/1,
/*src_node=*/0, /*src_local_rank=*/1, /*num_ranks_per_node=*/this->NumGPUs(),
/*src_node=*/0, /*src_local_rank=*/0,
/*in_place=*/true); /*in_place=*/true);
} }
// Test broadcast with increasing ranks. // Test broadcast with increasing ranks.
TYPED_TEST(NcclManagerTest, BroadcastWithDifferentRanks) { TYPED_TEST(NcclManagerTest, BroadcastWithDifferentRanks) {
#if TENSORFLOW_USE_ROCM for (int num_ranks = 1; num_ranks <= this->NumGPUs(); ++num_ranks) {
for (int num_ranks = 1; num_ranks <= 4; ++num_ranks)
#else
for (int num_ranks = 4; num_ranks <= 8; ++num_ranks)
#endif
{
const int src_rank = static_cast<int>(random::New64() % num_ranks); const int src_rank = static_cast<int>(random::New64() % num_ranks);
for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) { for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) {
const bool in_place = in_place_idx == 0; const bool in_place = in_place_idx == 0;
@ -606,42 +633,49 @@ TEST(NcclManagerTest, CommunicatorKey) {
#if !TENSORFLOW_USE_ROCM #if !TENSORFLOW_USE_ROCM
// This test creates `num_nodes` NcclManagers to simulate a multi-node // This test creates `num_nodes` NcclManagers to simulate a multi-node
// environment. It works on a single node and reuses GPUs. It enqueues NCCL // environment. It works on a single node with multiple GPUs. It enqueues NCCL
// kernels on separate stream per rank. // kernels on separate stream per rank.
TYPED_TEST(NcclManagerTest, MultiNode) { TYPED_TEST(NcclManagerTest, MultiNode) {
this->RunMultiNodeAllReduceTest(/*num_nodes=*/2, /*num_ranks_per_node=*/4); int num_nodes;
int num_ranks_per_node;
this->PopulateMultiNodeParams(&num_nodes, &num_ranks_per_node);
VLOG(1) << "Calling RunMultiNodeAllReduceTest with num_nodes=" << num_nodes
<< " and num_ranks_per_node=" << num_ranks_per_node;
this->RunMultiNodeAllReduceTest(num_nodes, num_ranks_per_node);
} }
#endif #endif
// Tests that specifying `communicator_key` with a single node NCCL collective // Tests that specifying `communicator_key` with a single node NCCL collective
// works well. // works well.
TYPED_TEST(NcclManagerTest, MultiNodeSingle) { TYPED_TEST(NcclManagerTest, MultiNodeSingle) {
this->RunMultiNodeAllReduceTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4); this->RunMultiNodeAllReduceTest(/*num_nodes=*/1,
/*num_ranks_per_node=*/this->NumGPUs());
} }
#if !TENSORFLOW_USE_ROCM
// Multi-node broadcast. // Multi-node broadcast.
TYPED_TEST(NcclManagerTest, MultiNodeBroadcast) { TYPED_TEST(NcclManagerTest, MultiNodeBroadcast) {
#if TENSORFLOW_USE_ROCM int num_nodes;
this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4, int num_ranks_per_node;
/*src_node=*/0, /*src_local_rank=*/3, this->PopulateMultiNodeParams(&num_nodes, &num_ranks_per_node);
/*in_place=*/true); VLOG(1) << "Calling RunMultiNodeBroadcastTest with num_nodes=" << num_nodes
#else << " and num_ranks_per_node=" << num_ranks_per_node;
this->RunMultiNodeBroadcastTest(/*num_nodes=*/4, /*num_ranks_per_node=*/8, this->RunMultiNodeBroadcastTest(num_nodes, num_ranks_per_node,
/*src_node=*/2, /*src_local_rank=*/3, /*src_node=*/0, /*src_local_rank=*/0,
/*in_place=*/true); /*in_place=*/true);
#endif #endif
} }
// Checks that we return error status if a collective_key is used for different // Checks that we return error status if a collective_key is used for different
// types of collectives, e.g. a reduction and a broadcast. // types of collectives, e.g.a reduction and a broadcast.
TYPED_TEST(NcclManagerTest, ConsistentCollectiveType) { TYPED_TEST(NcclManagerTest, ConsistentCollectiveType) {
const int num_ranks = 2; const int num_ranks = 2;
std::unique_ptr<typename TestFixture::TestCase> test_case( std::unique_ptr<typename TestFixture::TestCase> test_case(
this->MakeReductionTestCase(1 /* num_nodes */, num_ranks, ncclSum, this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum,
TensorShape({2, 3}), 0.0f)); TensorShape({2, 3}), 0.0f));
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
auto participant = absl::make_unique<NcclManager::Participant>( auto participant = absl::make_unique<NcclManager::Participant>(
@ -675,10 +709,10 @@ TYPED_TEST(NcclManagerTest, ConsistentCommunicatorKey) {
const int num_ranks = 2; const int num_ranks = 2;
std::unique_ptr<typename TestFixture::TestCase> test_case( std::unique_ptr<typename TestFixture::TestCase> test_case(
this->MakeReductionTestCase(1 /* num_nodes */, num_ranks, ncclSum, this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum,
TensorShape({2, 3}), 0.0f)); TensorShape({2, 3}), 0.0f));
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
auto participant = absl::make_unique<NcclManager::Participant>( auto participant = absl::make_unique<NcclManager::Participant>(
@ -704,10 +738,10 @@ TYPED_TEST(NcclManagerTest, ConsistentNumberOfDevices) {
const int num_ranks = 2; const int num_ranks = 2;
std::unique_ptr<typename TestFixture::TestCase> test_case( std::unique_ptr<typename TestFixture::TestCase> test_case(
this->MakeReductionTestCase(1 /* num_nodes */, num_ranks, ncclSum, this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum,
TensorShape({2, 3}), 0.0f)); TensorShape({2, 3}), 0.0f));
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
int num_devices = rank == 0 ? num_ranks : num_ranks + 1; int num_devices = rank == 0 ? num_ranks : num_ranks + 1;
@ -736,7 +770,7 @@ TYPED_TEST(NcclManagerTest, BroadcastNoSource) {
TensorShape({2, 3}), /*src_node=*/-1, TensorShape({2, 3}), /*src_node=*/-1,
/*src_rank=*/-1, false)); /*src_rank=*/-1, false));
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
auto participant = absl::make_unique<NcclManager::Participant>( auto participant = absl::make_unique<NcclManager::Participant>(
@ -762,7 +796,7 @@ TYPED_TEST(NcclManagerTest, BroadcastMultipleSends) {
TensorShape({2, 3}), /*src_node=*/-1, TensorShape({2, 3}), /*src_node=*/-1,
/*src_rank=*/-1, false)); /*src_rank=*/-1, false));
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
auto participant = absl::make_unique<NcclManager::Participant>( auto participant = absl::make_unique<NcclManager::Participant>(
@ -790,7 +824,7 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
TensorShape({2, 3}), /*src_node=*/-1, TensorShape({2, 3}), /*src_node=*/-1,
/*src_rank=*/-1, false)); /*src_rank=*/-1, false));
for (int rank = 0; rank < num_ranks; ++rank) { for (int rank = 0; rank < num_ranks; ++rank) {
auto* device = this->GetDevice(rank); auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
auto* info = device->tensorflow_gpu_device_info(); auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream; auto* stream = device->tensorflow_gpu_device_info()->stream;
auto participant = absl::make_unique<NcclManager::Participant>( auto participant = absl::make_unique<NcclManager::Participant>(