From 1fa274b1e2220c2d4a712276d0ad1882a2120988 Mon Sep 17 00:00:00 2001 From: Ayush Dubey Date: Tue, 4 Dec 2018 13:40:23 -0800 Subject: [PATCH] Better error checking and testing in NcclManager. After this change, we check the return value of every CUDA and NCCL call in NcclManager. If any call is unsuccessful, we call the NCCL callback with an error status. This change also re-enables NCCL tests. PiperOrigin-RevId: 224038066 --- tensorflow/core/nccl/BUILD | 38 ++++++------- tensorflow/core/nccl/nccl_manager.cc | 66 ++++++++++++++++------- tensorflow/core/nccl/nccl_manager.h | 8 ++- tensorflow/core/nccl/nccl_manager_test.cc | 49 +++++++++-------- 4 files changed, 93 insertions(+), 68 deletions(-) diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index 50d9a2e8daa..4be33b2a0cf 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -11,6 +11,10 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_copts") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) cc_library( name = "nccl_lib", @@ -34,27 +38,17 @@ cc_library( tf_cuda_cc_test( name = "nccl_manager_test", size = "medium", - srcs = if_cuda( - [ - "nccl_manager_test.cc", - ], - [], - ), - # Disabled on jenkins until errors finding nvmlShutdown are found. - tags = [ - "manual", - "multi_gpu", - "no_oss", - "noguitar", - "notap", + srcs = ["nccl_manager_test.cc"], + tags = tf_cuda_tests_tags() + [ + "no_cuda_on_cpu_tap", # TODO(b/120284216): re-enable multi_gpu ], - deps = - if_cuda([ - ":nccl_lib", - "@local_config_nccl//:nccl", - "//tensorflow/core:cuda", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ]), + deps = [ + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_cuda([ + ":nccl_lib", + "@local_config_nccl//:nccl", + "//tensorflow/core:cuda", + ]), ) diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc index f8e8c752227..df49bf1b976 100644 --- a/tensorflow/core/nccl/nccl_manager.cc +++ b/tensorflow/core/nccl/nccl_manager.cc @@ -24,6 +24,22 @@ limitations under the License. namespace tensorflow { +#define NCCL_RETURN_IF_ERROR(...) \ + do { \ + ncclResult_t nccl_status = (__VA_ARGS__); \ + if (nccl_status != ncclSuccess) { \ + return errors::Internal(ncclGetErrorString(nccl_status)); \ + } \ + } while (0) + +#define CUDA_RETURN_IF_ERROR(...) \ + do { \ + cudaError_t cuda_status = (__VA_ARGS__); \ + if (cuda_status != cudaSuccess) { \ + return errors::Internal(cudaGetErrorString(cuda_status)); \ + } \ + } while (0) + using se::cuda::ScopedActivateExecutorContext; // Contains data for a single stream used for nccl communication; this includes @@ -177,8 +193,8 @@ NcclManager* NcclManager::instance() { return instance; } -NcclManager::Communicator* NcclManager::GetCommunicator( - NcclManager::Collective* collective) { +Status NcclManager::GetCommunicator(NcclManager::Collective* collective, + NcclManager::Communicator** communicator) { // Sort by executor to make ordering of executors deterministic. std::sort(collective->participants.begin(), collective->participants.end(), [](const std::unique_ptr& a, @@ -217,7 +233,10 @@ NcclManager::Communicator* NcclManager::GetCommunicator( break; } } - if (i == num_devices) return comm.get(); + if (i == num_devices) { + *communicator = comm.get(); + return Status::OK(); + } } } @@ -264,37 +283,36 @@ NcclManager::Communicator* NcclManager::GetCommunicator( // NCCL2 prevents InitAll for more communicators than devices (but doesn't // check that device ids are unique). Work around it by initializing each // rank individually. - cudaGetDeviceCount(&device_count); + CUDA_RETURN_IF_ERROR(cudaGetDeviceCount(&device_count)); #endif std::vector nccl_comms(num_devices); if (num_devices <= device_count) { - auto result = - ncclCommInitAll(nccl_comms.data(), num_devices, devices.data()); - CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result); + NCCL_RETURN_IF_ERROR( + ncclCommInitAll(nccl_comms.data(), num_devices, devices.data())); } else { int savedDevice = 0; - CHECK_EQ(cudaGetDevice(&savedDevice), cudaSuccess); + CUDA_RETURN_IF_ERROR(cudaGetDevice(&savedDevice)); ncclUniqueId commId; - ncclGetUniqueId(&commId); + NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&commId)); #if NCCL_MAJOR >= 2 - CHECK_EQ(ncclGroupStart(), ncclSuccess); + NCCL_RETURN_IF_ERROR(ncclGroupStart()); #endif for (int rank = 0; rank < num_devices; ++rank) { - cudaSetDevice(devices[rank]); - auto result = - ncclCommInitRank(nccl_comms.data() + rank, num_devices, commId, rank); - CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result); + CUDA_RETURN_IF_ERROR(cudaSetDevice(devices[rank])); + NCCL_RETURN_IF_ERROR(ncclCommInitRank(nccl_comms.data() + rank, + num_devices, commId, rank)); } #if NCCL_MAJOR >= 2 - CHECK_EQ(ncclGroupEnd(), ncclSuccess); + NCCL_RETURN_IF_ERROR(ncclGroupEnd()); #endif - cudaSetDevice(savedDevice); + CUDA_RETURN_IF_ERROR(cudaSetDevice(savedDevice)); } for (int rank = 0; rank < num_devices; ++rank) { members[rank].nccl_comm = nccl_comms[rank]; } communicators_.emplace_back(new Communicator(std::move(members))); - return communicators_.back().get(); + *communicator = communicators_.back().get(); + return Status::OK(); } void NcclManager::AddToAllReduce(int num_devices, const string& key, @@ -400,10 +418,18 @@ void NcclManager::AddParticipant(int num_devices, const string& key, void NcclManager::RunCollective(const string& key, Collective* collective) { static mutex collective_mu(LINKER_INITIALIZED); - auto* communicator = GetCommunicator(collective); - collective->communicator = communicator; - const int size = communicator->num_devices; + Communicator* communicator = nullptr; + const int size = static_cast(collective->participants.size()); + Status s = GetCommunicator(collective, &communicator); + if (!s.ok()) { + for (int i = 0; i < size; ++i) { + collective->participants[i]->done_callback(s); + } + delete collective; + return; + } + collective->communicator = communicator; for (int rank = 0; rank < size; ++rank) { Participant* p = collective->participants[rank].get(); NcclStream* nccl_stream = communicator->members[rank].nccl_stream; diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h index 76b49101d47..5da4fe5554d 100644 --- a/tensorflow/core/nccl/nccl_manager.h +++ b/tensorflow/core/nccl/nccl_manager.h @@ -103,7 +103,13 @@ class NcclManager { struct NcclStream; struct Participant; - Communicator* GetCommunicator(Collective* collective); + // Gets the `Communicator` object that will be used to enqueue NCCL kernels + // for `collective`, and returns it via `communicator`. + // + // This may involve creating CUDA streams and NCCL initialization. If a NCCL + // or CUDA error occurs in the process, this returns an INTERNAL error with + // the corresponding NCCL/CUDA error string. + Status GetCommunicator(Collective* collective, Communicator** communicator); void AddParticipant(int num_devices, const string& key, std::unique_ptr participant, diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc index dbc07865f0b..f43103e120b 100644 --- a/tensorflow/core/nccl/nccl_manager_test.cc +++ b/tensorflow/core/nccl/nccl_manager_test.cc @@ -28,8 +28,8 @@ limitations under the License. namespace tensorflow { -static std::vector GetGPUDevices() { - std::vector devices; +static std::vector> GetGPUDevices() { + std::vector> devices; SessionOptions session_options; session_options.config.mutable_gpu_options() ->set_per_process_gpu_memory_fraction(0.1); @@ -37,12 +37,12 @@ static std::vector GetGPUDevices() { Status s = DeviceFactory::GetFactory(DEVICE_GPU) ->AddDevices(session_options, "", &devices); TF_CHECK_OK(s); - std::vector gpus; - for (Device* d : devices) { - if (d->device_type() == "GPU") { - gpus.push_back(static_cast(d)); - } else { - delete d; + std::vector> gpus; + for (std::unique_ptr& device : devices) { + if (device->device_type() == "GPU") { + // If `device_type()` is GPU, this `Device` is guaranteed to be a + // `BaseGPUDevice`, which is a subclass of `Device`. + gpus.emplace_back(static_cast(device.release())); } } return gpus; @@ -64,16 +64,14 @@ class NcclManagerTest : public ::testing::Test { }; static void SetUpTestCase() { - setenv("NCCL_DEBUG", "INFO", 1 /* replace */); - devices_ = new std::vector(GetGPUDevices()); - CHECK(!devices_->empty()); + setenv("NCCL_DEBUG", "WARN", 1 /* replace */); + devices_ = new std::vector>(GetGPUDevices()); LOG(ERROR) << "Running test with " << devices_->size() << " gpus"; } - static void TearDownTestCase() { - for (auto device : *devices_) delete device; - delete devices_; - } + static int32 NumGPUs() { return static_cast(devices_->size()); } + + static void TearDownTestCase() { delete devices_; } TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op, TensorShape shape, float value_offset) { @@ -153,7 +151,7 @@ class NcclManagerTest : public ::testing::Test { stream->ThenMemcpy(out_cpu.flat().data(), out_gpu_mem, out_cpu.TotalBytes()); SE_ASSERT_OK(stream->BlockHostUntilDone()); - test::ExpectTensorNear(test_case->expected, out_cpu, 0.01); + test::ExpectClose(test_case->expected, out_cpu); } } @@ -166,7 +164,7 @@ class NcclManagerTest : public ::testing::Test { } static BaseGPUDevice* GetDevice(size_t rank) { - return devices_->at(rank % devices_->size()); + return devices_->at(rank % devices_->size()).get(); } private: @@ -181,13 +179,14 @@ class NcclManagerTest : public ::testing::Test { } private: - static std::vector* devices_; + static std::vector>* devices_; static const DataType data_type_; static const Scalar max_; }; template -std::vector* NcclManagerTest::devices_ = nullptr; +std::vector>* NcclManagerTest::devices_ = + nullptr; template const DataType NcclManagerTest::data_type_ = DataTypeToEnum::value; @@ -195,13 +194,13 @@ template const Scalar NcclManagerTest::max_ = Eigen::NumTraits::highest(); -// Instantiate tests for float and half. -using TypeList = ::testing::Types; +// Instantiate tests for float and double. +using TypeList = ::testing::Types; TYPED_TEST_CASE(NcclManagerTest, TypeList); // Test basic sum reduction. TYPED_TEST(NcclManagerTest, BasicSumReduction) { - const int num_ranks = 3; + const int num_ranks = this->NumGPUs(); for (int op = 0; op < 4; ++op) { ncclRedOp_t reduction_op = static_cast(op); @@ -230,10 +229,10 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) { // To test the higher settings, increase num_ranks, // num_collectives_per_iteration and time_limit_micros. TYPED_TEST(NcclManagerTest, MultipleCallers) { - const int num_ranks = 1; // 2; - const int num_collectives_per_iteration = 1; // 1000; + const int num_ranks = this->NumGPUs(); + const int num_collectives_per_iteration = 10; // 1000; const int num_threads = 3; - const int time_limit_micros = 1; // 60 * 30 * 1000 * 1000; + const int time_limit_micros = 100; // 60 * 30 * 1000 * 1000; int64 start = Env::Default()->NowMicros(); srand(Env::Default()->NowMicros());