Disable collective_nccl_test
on single GPU and enable on multiple GPUs.
PiperOrigin-RevId: 289142542 Change-Id: I6b9c41f74062accc32173cc7afa4228e500bf31c
This commit is contained in:
parent
f6404f4f24
commit
bdb99e06c5
@ -239,7 +239,13 @@ tf_cuda_cc_test(
|
|||||||
name = "collective_nccl_test",
|
name = "collective_nccl_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["collective_nccl_test.cc"],
|
srcs = ["collective_nccl_test.cc"],
|
||||||
tags = tf_cuda_tests_tags() + ["no_cuda_on_cpu_tap"],
|
tags = tf_cuda_tests_tags() + [
|
||||||
|
"guitar",
|
||||||
|
"manual",
|
||||||
|
"multi_gpu",
|
||||||
|
"no_oss",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:all_kernels",
|
"//tensorflow/core:all_kernels",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
|
@ -81,20 +81,18 @@ class NcclTestBase : public ::testing::Test {
|
|||||||
class DeviceInstance;
|
class DeviceInstance;
|
||||||
|
|
||||||
NcclTestBase(CollectiveType collective_type, const string& collective_name)
|
NcclTestBase(CollectiveType collective_type, const string& collective_name)
|
||||||
: collective_type_(collective_type), collective_name_(collective_name) {}
|
: collective_type_(collective_type),
|
||||||
|
collective_name_(collective_name),
|
||||||
|
col_exec_(nullptr) {}
|
||||||
|
|
||||||
~NcclTestBase() override {
|
~NcclTestBase() override {
|
||||||
if (col_exec_) col_exec_->Unref();
|
if (col_exec_) col_exec_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitGPUDevices() {
|
void SetUp() {
|
||||||
std::vector<std::unique_ptr<Device>> all_devices;
|
std::vector<std::unique_ptr<Device>> all_devices;
|
||||||
SessionOptions session_options;
|
TF_CHECK_OK(DeviceFactory::GetFactory(DEVICE_GPU)
|
||||||
session_options.config.mutable_gpu_options()
|
->AddDevices(SessionOptions(), "", &all_devices));
|
||||||
->set_per_process_gpu_memory_fraction(0.1);
|
|
||||||
session_options.env = Env::Default();
|
|
||||||
Status s = DeviceFactory::GetFactory(DEVICE_GPU)
|
|
||||||
->AddDevices(session_options, "", &all_devices);
|
|
||||||
TF_CHECK_OK(s);
|
|
||||||
for (std::unique_ptr<Device>& d : all_devices) {
|
for (std::unique_ptr<Device>& d : all_devices) {
|
||||||
if (d->device_type() == "GPU") {
|
if (d->device_type() == "GPU") {
|
||||||
gpus_.emplace_back(std::move(d));
|
gpus_.emplace_back(std::move(d));
|
||||||
@ -105,13 +103,11 @@ class NcclTestBase : public ::testing::Test {
|
|||||||
void Init(const int num_ranks, const int instance_key) {
|
void Init(const int num_ranks, const int instance_key) {
|
||||||
setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
|
setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
|
||||||
setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
|
setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
|
||||||
InitGPUDevices();
|
|
||||||
std::vector<std::unique_ptr<Device>> local_devices;
|
std::vector<std::unique_ptr<Device>> local_devices;
|
||||||
std::vector<string> device_names;
|
std::vector<string> device_names;
|
||||||
|
CHECK_LE(num_ranks, gpus_.size());
|
||||||
for (int rank = 0; rank < num_ranks; ++rank) {
|
for (int rank = 0; rank < num_ranks; ++rank) {
|
||||||
if (rank < gpus_.size()) {
|
local_devices.emplace_back(std::move(gpus_[rank]));
|
||||||
local_devices.emplace_back(std::move(gpus_[rank]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
int num_gpus = local_devices.size();
|
int num_gpus = local_devices.size();
|
||||||
for (const auto& device : local_devices) {
|
for (const auto& device : local_devices) {
|
||||||
@ -180,6 +176,11 @@ class NcclTestBase : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void RunTest(int num_ranks, int input_length, int instance_key) {
|
void RunTest(int num_ranks, int input_length, int instance_key) {
|
||||||
|
if (num_ranks > gpus_.size()) {
|
||||||
|
LOG(WARNING) << "Skipping test because required " << num_ranks
|
||||||
|
<< " GPUs but found " << gpus_.size();
|
||||||
|
return;
|
||||||
|
}
|
||||||
Init(num_ranks, instance_key);
|
Init(num_ranks, instance_key);
|
||||||
std::vector<float> expected;
|
std::vector<float> expected;
|
||||||
InitExpected(&expected, input_length, num_ranks);
|
InitExpected(&expected, input_length, num_ranks);
|
||||||
|
Loading…
Reference in New Issue
Block a user