Add NcclManager::StartAbort

This allows aborting the NcclManager when the cluster is unhealthy. After the abort, any subsequent call to the NcclManager will error immediately.

After calling ncclCommAbort, ongoing and subsequent NCCL launches should error.

Note that this cannot abort NCCL initialization yet.

PiperOrigin-RevId: 332512565
Change-Id: I1cc53078f6cddeeea566ed18edbb45f7a2d833a0
Author: Ran Chen (2020-09-18 13:12:47 -07:00), committed by TensorFlower Gardener
parent 46d5b08525
commit f2ebefba65
7 changed files with 284 additions and 84 deletions
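
To make the described behavior concrete: below is a minimal, self-contained sketch of the fail-fast pattern this change introduces. An abort records a non-OK status under a mutex, pending work is failed through its done callbacks, and any later submission sees the recorded status and errors immediately. The names SimpleStatus and AbortableManager are invented for illustration; they are not part of the TensorFlow code in this change.

#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <vector>

// Hypothetical stand-in for tensorflow::Status.
struct SimpleStatus {
  bool ok = true;
  std::string message;
};

class AbortableManager {
 public:
  // Queues work, or fails immediately if the manager was already aborted.
  void AddParticipant(std::function<void(const SimpleStatus&)> done) {
    SimpleStatus current;
    {
      std::lock_guard<std::mutex> lock(mu_);
      current = status_;
      if (current.ok) {
        pending_.push_back(std::move(done));
        return;
      }
    }
    // Already aborted: invoke the callback outside the lock with the error.
    done(current);
  }

  // Fails all pending work and makes future AddParticipant calls error out.
  void StartAbort(const SimpleStatus& s) {
    std::vector<std::function<void(const SimpleStatus&)>> pending;
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (!status_.ok) return;  // Already aborted; ignore the second abort.
      status_ = s;
      pending.swap(pending_);
    }
    for (auto& done : pending) done(s);
  }

 private:
  std::mutex mu_;
  SimpleStatus status_;  // Non-OK once aborted.
  std::vector<std::function<void(const SimpleStatus&)>> pending_;
};

int main() {
  AbortableManager mgr;
  mgr.AddParticipant([](const SimpleStatus& s) {
    std::cout << "pending participant: " << (s.ok ? "ok" : s.message) << "\n";
  });
  mgr.StartAbort({false, "peer down"});  // Fails the pending participant.
  mgr.AddParticipant([](const SimpleStatus& s) {
    std::cout << "late participant: " << (s.ok ? "ok" : s.message) << "\n";
  });  // Errors immediately.
}

Running the sketch prints the abort error both for the participant that was already pending and for the one added after the abort, mirroring the guarantee stated in the commit message.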

View File

@ -218,6 +218,9 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) {
VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
cem_->GetParamResolver()->StartAbort(s);
remote_access_->StartAbort(s);
if (cem_->GetNcclCommunicator() != nullptr) {
cem_->GetNcclCommunicator()->StartAbort(s);
}
}
void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,

View File

@ -102,7 +102,6 @@ class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
   }
   NcclCommunicatorInterface* GetNcclCommunicator() const override {
-    LOG(FATAL) << "Unimplemented";  // Crash OK
     return nullptr;
   }

View File

@ -164,7 +164,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 }
 void NcclCommunicator::StartAbort(const Status& s) {
-  CHECK(false) << "not implemented yet";  // Crash ok.
+  nccl_manager_.StartAbort(s);
 }
 }  // namespace tensorflow

View File

@ -22,7 +22,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@ -279,6 +281,9 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
});
mutex_lock l(mu_);
if (!status_.ok()) {
return status_;
}
if (collective->communicator_key.empty()) {
// For single-node collectives, when the caller does not specify a
@ -487,6 +492,7 @@ void NcclManager::AddParticipant(std::unique_ptr<Participant> participant,
ncclRedOp_t reduction_op) {
Collective* to_run = nullptr;
DataType data_type;
Status nccl_manager_status;
if (participant->input != nullptr) {
data_type = participant->input->dtype();
} else {
@ -494,92 +500,100 @@ void NcclManager::AddParticipant(std::unique_ptr<Participant> participant,
   }
   {
     mutex_lock l(mu_);
-    auto collective_it = collectives_.find(context.collective_key);
-    Collective* collective = nullptr;
-    if (collective_it == collectives_.end()) {
-      collective =
-          new Collective(context.collective_key, data_type, collective_type,
-                         reduction_op, context.num_local_devices,
-                         context.num_global_devices, context.communicator_key);
-      collectives_.emplace(context.collective_key, collective);
-    } else {
-      collective = collective_it->second;
-    }
+    nccl_manager_status = status_;
+    if (nccl_manager_status.ok()) {
+      auto collective_it = collectives_.find(context.collective_key);
+      Collective* collective = nullptr;
+      if (collective_it == collectives_.end()) {
+        collective = new Collective(
+            context.collective_key, data_type, collective_type, reduction_op,
+            context.num_local_devices, context.num_global_devices,
+            context.communicator_key);
+        collectives_.emplace(context.collective_key, collective);
+      } else {
+        collective = collective_it->second;
+      }
-    // Check `collective` is correct and consistent.
-    if (collective->status.ok() && !collective->single_node &&
-        collective->communicator_key.empty()) {
-      collective->status = errors::Internal(
-          "Collective ", reduction_op, " is multi node with num_local_devices=",
-          collective->num_local_devices,
-          " and num_global_devices=", collective->num_global_devices,
-          " but has an empty communicator_key");
-    }
-    if (collective->status.ok() && collective->communicator_key.size() !=
-                                       context.communicator_key.size()) {
-      collective->status =
-          errors::Internal("Collective ", reduction_op,
-                           " mismatch in member communicator_key with size ",
-                           collective->communicator_key.size(),
-                           " and arg communicator_key with size ",
-                           context.communicator_key.size());
-    }
-    if (collective->status.ok() && collective->type != collective_type) {
-      collective->status = errors::Internal(
-          "Collective ", reduction_op, " previously initialized with type ",
-          collective->type, " but now got type ", collective_type);
-    }
-    if (collective->status.ok() &&
-        collective->num_global_devices != context.num_global_devices) {
-      collective->status =
-          errors::Internal("Collective ", reduction_op,
-                           " previously initialized with num_global_devices ",
-                           collective->num_global_devices, " but now got ",
-                           context.num_global_devices);
-    }
-    if (collective->status.ok() &&
-        collective->num_local_devices != context.num_local_devices) {
-      collective->status =
-          errors::Internal("Collective ", reduction_op,
-                           "previously initialized with num_local_devices ",
-                           collective->num_local_devices, " but now got ",
-                           context.num_local_devices);
-    }
-    if (collective->status.ok() &&
-        collective->participants.size() >= collective->num_local_devices) {
-      collective->status = errors::Internal(
-          "Collective ", reduction_op, " expected ",
-          collective->num_local_devices, " participants but now has ",
-          collective->participants.size(),
-          " with one more participant being added");
-    }
-    if (collective->status.ok() && collective->root_rank >= 0 &&
-        context.source_rank >= 0 &&
-        collective->root_rank != context.source_rank) {
-      collective->status = errors::Internal(
-          "Collective ", collective->collective_key, " already has root_rank ",
-          collective->root_rank, " but new participant has root_rank ",
-          context.source_rank);
-    }
-    if (collective->status.ok() &&
-        !kValidDataTypes.Contains(collective->data_type)) {
-      collective->status = errors::Internal(
-          "Collective ", collective->collective_key,
-          " expected data types compatible with NCCL but instead got ",
-          DataTypeString(collective->data_type));
-    }
+      // Check `collective` is correct and consistent.
+      if (collective->status.ok() && !collective->single_node &&
+          collective->communicator_key.empty()) {
+        collective->status = errors::Internal(
+            "Collective ", reduction_op,
+            " is multi node with num_local_devices=",
+            collective->num_local_devices,
+            " and num_global_devices=", collective->num_global_devices,
+            " but has an empty communicator_key");
+      }
+      if (collective->status.ok() && collective->communicator_key.size() !=
+                                         context.communicator_key.size()) {
+        collective->status =
+            errors::Internal("Collective ", reduction_op,
+                             " mismatch in member communicator_key with size ",
+                             collective->communicator_key.size(),
+                             " and arg communicator_key with size ",
+                             context.communicator_key.size());
+      }
+      if (collective->status.ok() && collective->type != collective_type) {
+        collective->status = errors::Internal(
+            "Collective ", reduction_op, " previously initialized with type ",
+            collective->type, " but now got type ", collective_type);
+      }
+      if (collective->status.ok() &&
+          collective->num_global_devices != context.num_global_devices) {
+        collective->status =
+            errors::Internal("Collective ", reduction_op,
+                             " previously initialized with num_global_devices ",
+                             collective->num_global_devices, " but now got ",
+                             context.num_global_devices);
+      }
+      if (collective->status.ok() &&
+          collective->num_local_devices != context.num_local_devices) {
+        collective->status =
+            errors::Internal("Collective ", reduction_op,
+                             "previously initialized with num_local_devices ",
+                             collective->num_local_devices, " but now got ",
+                             context.num_local_devices);
+      }
+      if (collective->status.ok() &&
+          collective->participants.size() >= collective->num_local_devices) {
+        collective->status = errors::Internal(
+            "Collective ", reduction_op, " expected ",
+            collective->num_local_devices, " participants but now has ",
+            collective->participants.size(),
+            " with one more participant being added");
+      }
+      if (collective->status.ok() && collective->root_rank >= 0 &&
+          context.source_rank >= 0 &&
+          collective->root_rank != context.source_rank) {
+        collective->status = errors::Internal(
+            "Collective ", collective->collective_key,
+            " already has root_rank ", collective->root_rank,
+            " but new participant has root_rank ", context.source_rank);
+      }
+      if (collective->status.ok() &&
+          !kValidDataTypes.Contains(collective->data_type)) {
+        collective->status = errors::Internal(
+            "Collective ", collective->collective_key,
+            " expected data types compatible with NCCL but instead got ",
+            DataTypeString(collective->data_type));
+      }
-    if (context.source_rank >= 0) {
-      collective->root_rank = context.source_rank;
-    }
-    collective->participants.emplace_back(std::move(participant));
-    ++collective->available_participants;
+      if (context.source_rank >= 0) {
+        collective->root_rank = context.source_rank;
+      }
-    if (CheckReady(context.collective_key, collective)) {
-      to_run = collective;
+      collective->participants.emplace_back(std::move(participant));
+      ++collective->available_participants;
+      if (CheckReady(context.collective_key, collective)) {
+        to_run = collective;
+      }
     }
   }
+  if (!nccl_manager_status.ok()) {
+    participant->done_callback(nccl_manager_status);
+    return;
+  }
   if (to_run != nullptr) RunCollective(to_run);
 }
@ -834,6 +848,52 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
}
}
void NcclManager::StartAbort(const Status& s) {
VLOG(1) << "NcclManager StartAbort";
absl::flat_hash_map<string, Collective*> collectives;
// After status_ is set to a non-OK one, there should be no further
// modifications to collectives_.
{
mutex_lock l(mu_);
if (!status_.ok()) {
LOG(WARNING)
<< "NcclManager already aborted, ignoring subsequent StartAbort with "
<< s;
return;
}
status_ = s;
collectives.swap(collectives_);
}
// collectives_ contains pending launches that haven't been dispatched to
// kernel launch threads, so we can simply invoke the done callbacks of them.
for (const auto& item : collectives) {
for (const std::unique_ptr<Participant>& p : item.second->participants) {
p->done_callback(s);
}
item.second->Unref();
}
// Abort ncclComm. Note that there could be multiple ncclComm per device, and
// ncclCommAbort contains cuda calls that requires device synchronization.
// That is a collective on nccl_comm_0 can block ncclCommAbort(nccl_comm_1),
// so we need to abort all ncclComm in a concurrent fashion. This assumes that
// there's only one active NcclManager at a time.
UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
int num_comms = 0;
for (std::unique_ptr<Communicator>& communicator : communicators_) {
num_comms += communicator->members.size();
}
BlockingCounter pending(num_comms);
for (std::unique_ptr<Communicator>& communicator : communicators_) {
for (const CommunicatorMember& member : communicator->members) {
queue.Schedule([&member, &pending]() {
ncclCommAbort(member.nccl_comm);
pending.DecrementCount();
});
}
}
pending.Wait();
}
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
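
The comment in NcclManager::StartAbort above explains that ncclCommAbort may block on work pending on a different communicator, so all communicators are aborted concurrently and then awaited. Below is a stripped-down sketch of that fan-out-and-wait shape; plain std::thread stands in for TF's UnboundedWorkQueue and BlockingCounter, and the FakeCommAbort function is a made-up placeholder for ncclCommAbort.

#include <chrono>
#include <cstdio>
#include <thread>
#include <vector>

// Placeholder for ncclCommAbort(comm): pretend each abort may block for a
// while (e.g. on device synchronization) before returning.
void FakeCommAbort(int comm_id) {
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  std::printf("aborted comm %d\n", comm_id);
}

int main() {
  const std::vector<int> comms = {0, 1, 2, 3};

  // Abort every communicator concurrently: if one abort can block on another
  // communicator's outstanding work, aborting them serially could deadlock,
  // whereas concurrent aborts all make progress.
  std::vector<std::thread> workers;
  workers.reserve(comms.size());
  for (int comm_id : comms) {
    workers.emplace_back([comm_id] { FakeCommAbort(comm_id); });
  }
  // Equivalent of BlockingCounter::Wait(): block until every abort finished.
  for (std::thread& t : workers) t.join();
  std::printf("all communicators aborted\n");
}

Joining every thread plays the role of BlockingCounter::Wait(); the property that matters is that no abort has to finish before another one can start.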

View File

@ -202,6 +202,10 @@ class NcclManager {
// function.
void SignalMultiNodeReady(const string& collective_key);
// Aborts all collectives. After abortion, no further collectives can be
// launched with this NcclManager.
void StartAbort(const Status& s);
private:
enum CollectiveType {
kAllReduce = 1,
@ -257,6 +261,8 @@ class NcclManager {
std::vector<std::unique_ptr<Communicator>> communicators_;
Status status_ TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
};

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/nccl/nccl_manager.h"
@ -300,8 +302,14 @@ class NcclManagerTest : public ::testing::Test {
   void RunMultiNodeAllReduceTest(const int num_nodes,
                                  const int num_ranks_per_node) {
-    const int num_global_ranks = num_nodes * num_ranks_per_node;
+    std::vector<NodeState> node_states(num_nodes);
+    RunMultiNodeAllReduceTest(node_states, num_ranks_per_node);
+  }
+  void RunMultiNodeAllReduceTest(std::vector<NodeState>& node_states,
+                                 const int num_ranks_per_node) {
+    const int num_nodes = node_states.size();
+    const int num_global_ranks = num_nodes * num_ranks_per_node;
     const string collective_key = "allreduce";
     // The NcclManagers in this test synchronize in real-time, so we need to run
     // each node's code in a separate thread.
@ -842,6 +850,68 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
this->VerifyError(test_case.get());
}
TYPED_TEST(NcclManagerTest, Abort) {
using NodeState = typename TestFixture::NodeState;
using TestCase = typename TestFixture::TestCase;
int num_nodes = 2;
std::vector<NodeState> nodes(num_nodes);
// First do a normal all-reduce to simulate the the case when there're
// multiple communicators.
this->RunMultiNodeAllReduceTest(nodes, /* num_ranks_per_node */ 1);
// Use a new communicator_key, which uses a new set of ncclComm underneath.
string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
string collective_key = "allreduce";
ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(0);
auto node_fn = [&](TestCase* test_case, int node) {
auto* device = this->GetDevice(/* num_ranks_per_node */ 1, node,
/* local_rank */ 0);
auto* info = device->tensorflow_gpu_device_info();
auto* stream = device->tensorflow_gpu_device_info()->stream;
auto participant = absl::make_unique<NcclManager::Participant>(
device->executor(), stream, info, &test_case->ins[node],
&test_case->outs[node], /* global_rank */ node,
this->CreateDoneCallback(test_case));
nodes[node].nccl_manager.AddToAllReduce(
std::move(participant),
{collective_key, /* num_local_devices */ 1,
/* num_global_devices */ num_nodes, communicator_key,
/*source_rank=*/-1},
reduction_op);
nodes[node].nccl_manager.SignalMultiNodeReady(collective_key);
};
// Do a normal all-reduce with this communicator key to initialize ncclComm.
// This is because ncclCommInitRank waits for all ranks and is blocking.
{
std::unique_ptr<typename TestFixture::TestCase> test_case(
this->MakeReductionTestCase(
/* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
TensorShape({2, 3}), 0.0f));
for (int i = 0; i < num_nodes; ++i) {
this->work_queue_->Schedule(
[&node_fn, &test_case, i]() { node_fn(test_case.get(), i); });
}
this->VerifyResults(test_case.get());
}
// A hanging all-reduce.
ASSERT_GT(num_nodes, 1);
std::unique_ptr<typename TestFixture::TestCase> test_case(
this->MakeReductionTestCase(
/* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
TensorShape({2, 3}), 0.0f));
node_fn(test_case.get(), 0);
Env::Default()->SleepForMicroseconds(1000000);
nodes[0].nccl_manager.StartAbort(errors::Unavailable("peer down"));
{
mutex_lock l(test_case->mu);
while (test_case->num_completed != 1) {
test_case->done_cv.wait(l);
}
}
}
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
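
The Abort test above deliberately hangs an all-reduce on node 0 and relies on StartAbort to release the blocked waiter with an error. A minimal sketch of that unblock-on-abort idiom, using invented names and plain C++ synchronization primitives instead of the NCCL and test fixtures:

#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>

// Hypothetical shared state: a waiter blocks until the collective completes
// or the manager is aborted with an error.
struct CollectiveState {
  std::mutex mu;
  std::condition_variable cv;
  bool done = false;
  std::string error;  // Non-empty once aborted.
};

int main() {
  CollectiveState state;

  // Waiter: stands in for the test thread blocked on done_cv.wait(l).
  std::thread waiter([&state] {
    std::unique_lock<std::mutex> lock(state.mu);
    state.cv.wait(lock, [&state] { return state.done; });
    std::cout << "collective finished with: "
              << (state.error.empty() ? "OK" : state.error) << "\n";
  });

  // Abort: stands in for StartAbort invoking the pending done callbacks.
  std::this_thread::sleep_for(std::chrono::seconds(1));
  {
    std::lock_guard<std::mutex> lock(state.mu);
    state.error = "peer down";
    state.done = true;
  }
  state.cv.notify_all();

  waiter.join();
}

The waiter blocks exactly like the test's done_cv.wait loop and is released only after another thread records the error and signals completion.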

View File

@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
import os
import threading
import time
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -27,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import test
@ -301,6 +304,65 @@ class CollectiveOpGPUTest(test.TestCase):
[1.], group_size=1, group_key=0, instance_key=0, merge_op='Add',
final_op='Id', communication_hint='NCCL')
@test_util.run_v2_only
def testAbortNccl(self):
self._setup_context(num_gpus=2)
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant(1.)
# First perform a normal collective to finish resolution.
def collective_fn():
for device in ['GPU:0', 'GPU:1']:
with ops.device(device):
collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key,
'Add',
'Id',
communication_hint='nccl')
def_function.function(collective_fn)()
# Launch a collective that hangs, and abort the collective executor after
# the launch.
def abort_fn():
time.sleep(2)
context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')
t = threading.Thread(target=abort_fn)
t.start()
with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key,
'Add',
'Id',
communication_hint='nccl')
# After abortion, subsequent collectives should fail immediately.
with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key,
'Add',
'Id',
communication_hint='nccl')
t.join()
# Reset the context in order to reset the collective executor.
context._reset_context() # pylint: disable=protected-access
def_function.function(collective_fn)()
if __name__ == '__main__':
test.main()