Disable collective_nccl_test on single GPU and enable on multiple GPUs.
PiperOrigin-RevId: 289142542
Change-Id: I6b9c41f74062accc32173cc7afa4228e500bf31c
This commit is contained in:
parent
f6404f4f24
commit
bdb99e06c5
tensorflow/core/kernels
@ -239,7 +239,13 @@ tf_cuda_cc_test(
|
||||
name = "collective_nccl_test",
|
||||
size = "small",
|
||||
srcs = ["collective_nccl_test.cc"],
|
||||
tags = tf_cuda_tests_tags() + ["no_cuda_on_cpu_tap"],
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"guitar",
|
||||
"manual",
|
||||
"multi_gpu",
|
||||
"no_oss",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
@ -81,20 +81,18 @@ class NcclTestBase : public ::testing::Test {
|
||||
class DeviceInstance;
|
||||
|
||||
NcclTestBase(CollectiveType collective_type, const string& collective_name)
|
||||
: collective_type_(collective_type), collective_name_(collective_name) {}
|
||||
: collective_type_(collective_type),
|
||||
collective_name_(collective_name),
|
||||
col_exec_(nullptr) {}
|
||||
|
||||
~NcclTestBase() override {
|
||||
if (col_exec_) col_exec_->Unref();
|
||||
}
|
||||
|
||||
void InitGPUDevices() {
|
||||
void SetUp() {
|
||||
std::vector<std::unique_ptr<Device>> all_devices;
|
||||
SessionOptions session_options;
|
||||
session_options.config.mutable_gpu_options()
|
||||
->set_per_process_gpu_memory_fraction(0.1);
|
||||
session_options.env = Env::Default();
|
||||
Status s = DeviceFactory::GetFactory(DEVICE_GPU)
|
||||
->AddDevices(session_options, "", &all_devices);
|
||||
TF_CHECK_OK(s);
|
||||
TF_CHECK_OK(DeviceFactory::GetFactory(DEVICE_GPU)
|
||||
->AddDevices(SessionOptions(), "", &all_devices));
|
||||
for (std::unique_ptr<Device>& d : all_devices) {
|
||||
if (d->device_type() == "GPU") {
|
||||
gpus_.emplace_back(std::move(d));
|
||||
@ -105,13 +103,11 @@ class NcclTestBase : public ::testing::Test {
|
||||
void Init(const int num_ranks, const int instance_key) {
|
||||
setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
|
||||
setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
|
||||
InitGPUDevices();
|
||||
std::vector<std::unique_ptr<Device>> local_devices;
|
||||
std::vector<string> device_names;
|
||||
CHECK_LE(num_ranks, gpus_.size());
|
||||
for (int rank = 0; rank < num_ranks; ++rank) {
|
||||
if (rank < gpus_.size()) {
|
||||
local_devices.emplace_back(std::move(gpus_[rank]));
|
||||
}
|
||||
local_devices.emplace_back(std::move(gpus_[rank]));
|
||||
}
|
||||
int num_gpus = local_devices.size();
|
||||
for (const auto& device : local_devices) {
|
||||
@ -180,6 +176,11 @@ class NcclTestBase : public ::testing::Test {
|
||||
}
|
||||
|
||||
void RunTest(int num_ranks, int input_length, int instance_key) {
|
||||
if (num_ranks > gpus_.size()) {
|
||||
LOG(WARNING) << "Skipping test because required " << num_ranks
|
||||
<< " GPUs but found " << gpus_.size();
|
||||
return;
|
||||
}
|
||||
Init(num_ranks, instance_key);
|
||||
std::vector<float> expected;
|
||||
InitExpected(&expected, input_length, num_ranks);
|
||||
|
Loading…
Reference in New Issue
Block a user