Enable resetting NcclManager after a previous StartAbort.

After calling Reset, the NcclManager is again available to launch collectives.

PiperOrigin-RevId: 338296784
Change-Id: I4ebc0fedbd84318ba08a1571e2a1c91e765f6b5a
Authored by Ayush Dubey 2020-10-21 10:55:40 -07:00, committed by TensorFlower Gardener
parent 259ffa9ea6
commit 1c7803fa56
3 changed files with 52 additions and 12 deletions
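A minimal caller-side sketch of the intended sequence, based only on the API shown in this change (StartAbort, Reset, GenerateCommunicatorKey). The helper AbortAndRecover and the error message are illustrative, not part of this commit:

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// Hypothetical helper: abort the manager when a peer fails, then reset it so
// later collectives can be launched again, as the updated test below does.
void AbortAndRecover(NcclManager* nccl_manager) {
  // Fail every pending collective and discard the existing communicators.
  nccl_manager->StartAbort(errors::Unavailable("peer down"));

  // After recovery, make the manager available for new collectives.
  nccl_manager->Reset();

  // The old communicators are gone, so subsequent collectives need a freshly
  // generated communicator key (the test regenerates it the same way).
  string new_communicator_key = nccl_manager->GenerateCommunicatorKey();
  (void)new_communicator_key;  // ... launch new collectives with this key ...
}

}  // namespace tensorflow
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM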

tensorflow/core/nccl/nccl_manager.cc (View File)

@@ -739,6 +739,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
VLOG(2) << "call NcclAllReduce collective_key "
<< collective->collective_key << " participant " << p_idx
<< " num_participants " << collective->participants.size()
<< " sendbuff " << sendbuff << " recvbuff " << recvbuff
<< " nccl_comm " << nccl_comm << " comm_stream " << comm_stream
<< " cuda_stream " << cu_stream;
@@ -849,7 +850,6 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
}
void NcclManager::StartAbort(const Status& s) {
VLOG(1) << "NcclManager StartAbort";
absl::flat_hash_map<string, Collective*> collectives;
std::vector<std::unique_ptr<Communicator>> communicators;
{
@@ -864,6 +864,9 @@ void NcclManager::StartAbort(const Status& s) {
collectives.swap(collectives_);
communicators.swap(communicators_);
}
VLOG(2) << "Aborted NcclManager " << this << " with " << collectives.size()
<< " collectives and " << communicators.size()
<< " comms with status " << s;
// collectives_ contains pending launches that haven't been dispatched to
// kernel launch threads, so we can simply invoke their done callbacks.
for (const auto& item : collectives) {
@@ -895,6 +898,12 @@ void NcclManager::StartAbort(const Status& s) {
pending.Wait();
}
void NcclManager::Reset() {
mutex_lock l(mu_);
status_ = Status();
VLOG(2) << "Reset NcclManager " << this;
}
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
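For readers outside the TensorFlow tree, here is a stand-alone sketch (not TensorFlow code; every name is illustrative) of the pattern the two functions above follow: StartAbort swaps the pending work out under the lock, fails it with the supplied status, and leaves the error status set, while Reset clears that status so new work is accepted again. The fail-fast behavior in Add is an assumption about how work is rejected while aborted, not something shown in this diff.

#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-in for a manager with a sticky abort status.
class AbortableManager {
 public:
  // Queues a done callback, or (assumption) fails it immediately if the
  // manager is currently aborted.
  void Add(std::function<void(const std::string&)> done) {
    std::string status;
    {
      std::lock_guard<std::mutex> l(mu_);
      if (status_.empty()) {
        pending_.push_back(std::move(done));
        return;
      }
      status = status_;
    }
    done(status);  // Invoke the callback outside the lock.
  }

  // Analogue of NcclManager::StartAbort: record the error, swap the pending
  // work out under the lock, then fail it outside the lock.
  void StartAbort(const std::string& error) {
    std::vector<std::function<void(const std::string&)>> pending;
    {
      std::lock_guard<std::mutex> l(mu_);
      status_ = error;
      pending.swap(pending_);
    }
    for (auto& done : pending) done(error);
  }

  // Analogue of NcclManager::Reset: clear the sticky status so the manager
  // accepts new work again.
  void Reset() {
    std::lock_guard<std::mutex> l(mu_);
    status_.clear();
  }

 private:
  std::mutex mu_;
  std::string status_;  // Empty means OK.
  std::vector<std::function<void(const std::string&)>> pending_;
};

int main() {
  AbortableManager m;
  m.Add([](const std::string& s) { std::cout << "pending failed: " << s << "\n"; });
  m.StartAbort("peer down");  // Prints "pending failed: peer down".
  m.Add([](const std::string& s) { std::cout << "rejected: " << s << "\n"; });
  m.Reset();  // The manager accepts work again, mirroring NcclManager::Reset.
  m.Add([](const std::string&) { /* queued; only invoked on failure */ });
  return 0;
}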

tensorflow/core/nccl/nccl_manager.h (View File)

@@ -195,6 +195,10 @@ class NcclManager {
// launched with this NcclManager.
void StartAbort(const Status& s);
// Resets a previously aborted NcclManager, making it available for future
// collectives.
void Reset();
private:
enum CollectiveType {
kAllReduce = 1,

tensorflow/core/nccl/nccl_manager_test.cc (View File)

@@ -17,8 +17,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/nccl/nccl_manager.h"
#include <algorithm>
#include <random>
#include <vector>
@@ -27,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
@@ -863,20 +862,19 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
// environment, on a single node with multiple GPUs. So tests that rely
// upon such simulation need to be skipped on the ROCm platform.
TYPED_TEST(NcclManagerTest, Abort) {
TYPED_TEST(NcclManagerTest, AbortThenReset) {
using NodeState = typename TestFixture::NodeState;
using TestCase = typename TestFixture::TestCase;
int num_nodes = 2;
const int num_nodes = 2;
std::vector<NodeState> nodes(num_nodes);
// First do a normal all-reduce to simulate the case when there are
// multiple communicators.
this->RunMultiNodeAllReduceTest(nodes, /* num_ranks_per_node */ 1);
// Use a new communicator_key, which uses a new set of ncclComm underneath.
string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
string collective_key = "allreduce";
const string collective_key = "allreduce";
ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(0);
auto node_fn = [&](TestCase* test_case, int node) {
auto node_fn = [&](TestCase* test_case, int node,
const string& communicator_key) {
auto* device = this->GetDevice(/* num_ranks_per_node */ 1, node,
/* local_rank */ 0);
auto* info = device->tensorflow_gpu_device_info();
@@ -894,6 +892,8 @@ TYPED_TEST(NcclManagerTest, Abort) {
nodes[node].nccl_manager.SignalMultiNodeReady(collective_key);
};
// Use a new communicator_key, which uses a new set of ncclComm underneath.
string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
// Do a normal all-reduce with this communicator key to initialize ncclComm.
// This is because ncclCommInitRank waits for all ranks and is blocking.
{
@@ -903,7 +903,9 @@ TYPED_TEST(NcclManagerTest, Abort) {
TensorShape({2, 3}), 0.0f));
for (int i = 0; i < num_nodes; ++i) {
this->work_queue_->Schedule(
[&node_fn, &test_case, i]() { node_fn(test_case.get(), i); });
[&node_fn, &test_case, i, communicator_key]() {
node_fn(test_case.get(), i, communicator_key);
});
}
this->VerifyResults(test_case.get());
}
@@ -914,16 +916,41 @@ TYPED_TEST(NcclManagerTest, Abort) {
this->MakeReductionTestCase(
/* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
TensorShape({2, 3}), 0.0f));
node_fn(test_case.get(), 0);
node_fn(test_case.get(), 0, communicator_key);
Env::Default()->SleepForMicroseconds(1000000);
nodes[0].nccl_manager.StartAbort(errors::Unavailable("peer down"));
for (auto& node : nodes) {
node.nccl_manager.StartAbort(errors::Unavailable("peer down"));
}
{
mutex_lock l(test_case->mu);
while (test_case->num_completed != 1) {
test_case->done_cv.wait(l);
}
}
// Reset the aborted NcclManagers and then run another all-reduce with the
// reset managers.
for (auto& node : nodes) {
node.nccl_manager.Reset();
}
// Regenerate the communicator_key, because this is needed to create new
// communicators.
communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
{
std::unique_ptr<typename TestFixture::TestCase> test_case(
this->MakeReductionTestCase(
/* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
TensorShape({2, 3}), 0.0f));
for (int i = 0; i < num_nodes; ++i) {
this->work_queue_->Schedule(
[&node_fn, &test_case, i, communicator_key]() {
node_fn(test_case.get(), i, communicator_key);
});
}
this->VerifyResults(test_case.get());
}
}
#endif
} // namespace tensorflow