Fix a use-after-free when collective times out

Since NcclCommunicator is owned by CollectiveExecutorMgr, and CollectiveExecutorMgr may be destructed while the timeout callback is executing, there could be use-after-free.

Given the timeout is a feature only for debugging, it doesn't seem worth the cost of changing everything to shared_ptr. Instead, we keep the done() that times out after the StartAbort(), to ensure that all the resources are still available because there're still one pending kernel.

The change in NcclManager is mostly to silence asan.

PiperOrigin-RevId: 337928891
Change-Id: I855979ade4e0309ac2c4e2cd9c4d844588ee6b4f
This commit is contained in:
Ran Chen 2020-10-19 14:06:21 -07:00 committed by TensorFlower Gardener
parent c031923133
commit af93956653
4 changed files with 38 additions and 43 deletions

View File

@ -227,31 +227,27 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
const CollectiveParams& col_params,
const string& exec_key,
StatusCallback done) {
// See CompleteParamsAsync() how done() and the timeout callback interacts.
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
StatusCallback done_safe = [done = std::move(done),
is_callback_called](const Status& s) {
auto done_safe = [done, is_callback_called](const Status& s) {
bool called = is_callback_called->exchange(true);
CHECK(!called) << "done callback is called twice in " // Crash OK
"BaseCollectiveExecutor::ExecuteAsync. Please file a "
"issue on https://github.com/tensorflow/tensorflow.";
done(s);
if (!called) {
done(s);
}
};
auto timeout_microseconds = static_cast<int64>(
col_params.instance.impl_details.timeout_seconds * 1'000'000);
if (timeout_microseconds > 0) {
// Ensure this BaseCollectiveExecutor is alive when StartAbort() is called.
Ref();
// TODO(xldrx): Share the timeout watchdog thread among collectives.
SchedNonBlockingClosureAfter(
timeout_microseconds, [this, is_callback_called] {
if (!is_callback_called->load()) {
timeout_microseconds, [this, is_callback_called, done] {
bool called = is_callback_called->exchange(true);
if (!called) {
Status status(error::DEADLINE_EXCEEDED,
"Collective has timed out during execution.");
StartAbort(status);
done(status);
}
Unref();
});
}
@ -306,30 +302,34 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr, StatusCallback done) {
cp->group.gpu_ring_order = *gpu_ring_order_;
// We need to make sure that when the timeout callback executes,
// CollectiveExecutor and CollectiveExecutorMgr are both alive. After done()
// is called, CollectiveExecutorMgr may be destructed and we don't have a way
// to keep it without making the ownerships more complicated. Therefore if the
// timeout callback executes, done_safe will become a no-op and the timeout
// callback is responsible for invoking done() at the end.
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_safe = [is_callback_called,
done = std::move(done)](const Status& s) {
auto done_safe = [done, is_callback_called](const Status& s) {
bool called = is_callback_called->exchange(true);
CHECK(!called) << "done callback is called twice in " // Crash OK
"BaseCollectiveExecutor::ExecuteAsync. Please file a "
"issue on https://github.com/tensorflow/tensorflow.";
done(s);
if (!called) {
done(s);
}
};
auto timeout_microseconds =
static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
if (timeout_microseconds > 0) {
// Ensure this BaseCollectiveExecutor is alive when StartAbort() is called.
Ref();
// TODO(xldrx): Share the timeout watchdog thread among collectives.
SchedNonBlockingClosureAfter(timeout_microseconds, [this,
is_callback_called]() {
if (!is_callback_called->load()) {
Status status(error::DEADLINE_EXCEEDED,
"Collective has timed out waiting for other workers.");
StartAbort(status);
}
Unref();
});
SchedNonBlockingClosureAfter(
timeout_microseconds, [this, is_callback_called, done]() {
bool called = is_callback_called->exchange(true);
if (!called) {
Status status(
error::DEADLINE_EXCEEDED,
"Collective has timed out waiting for other workers.");
StartAbort(status);
done(status);
}
});
}
cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
done_safe);

View File

@ -115,7 +115,7 @@ struct NcclManager::Communicator {
: num_devices(members.size()), members(std::move(members)), key(key) {}
const int num_devices;
const std::vector<CommunicatorMember> members;
std::vector<CommunicatorMember> members;
const string key;
};
@ -851,8 +851,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
void NcclManager::StartAbort(const Status& s) {
VLOG(1) << "NcclManager StartAbort";
absl::flat_hash_map<string, Collective*> collectives;
// After status_ is set to a non-OK one, there should be no further
// modifications to collectives_.
std::vector<std::unique_ptr<Communicator>> communicators;
{
mutex_lock l(mu_);
if (!status_.ok()) {
@ -863,6 +862,7 @@ void NcclManager::StartAbort(const Status& s) {
}
status_ = s;
collectives.swap(collectives_);
communicators.swap(communicators_);
}
// collectives_ contains pending launches that haven't been dispatched to
// kernel launch threads, so we can simply invoke the done callbacks of them.
@ -879,14 +879,15 @@ void NcclManager::StartAbort(const Status& s) {
// there's only one active NcclManager at a time.
UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
int num_comms = 0;
for (std::unique_ptr<Communicator>& communicator : communicators_) {
for (std::unique_ptr<Communicator>& communicator : communicators) {
num_comms += communicator->members.size();
}
BlockingCounter pending(num_comms);
for (std::unique_ptr<Communicator>& communicator : communicators_) {
for (const CommunicatorMember& member : communicator->members) {
for (std::unique_ptr<Communicator>& communicator : communicators) {
for (CommunicatorMember& member : communicator->members) {
queue.Schedule([&member, &pending]() {
ncclCommAbort(member.nccl_comm);
member.nccl_comm = nullptr;
pending.DecrementCount();
});
}

View File

@ -248,7 +248,7 @@ class NcclManager {
absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
device_to_comm_streams_ TF_GUARDED_BY(mu_);
std::vector<std::unique_ptr<Communicator>> communicators_;
std::vector<std::unique_ptr<Communicator>> communicators_ TF_GUARDED_BY(mu_);
Status status_ TF_GUARDED_BY(mu_);

View File

@ -491,8 +491,6 @@ class TimeoutTest(test.TestCase, parameterized.TestCase):
super().setUp()
def testTimeout(self, collective_op, device, communication):
if device == 'GPU':
self.skipTest('b/170980122')
timeout = 1.5
@def_function.function
@ -527,8 +525,6 @@ class TimeoutTest(test.TestCase, parameterized.TestCase):
def testParamResolutionAfterTimeout(self, collective_op, device,
communication):
if device == 'GPU':
self.skipTest('b/170980122')
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
timeout = 1.5
@ -564,8 +560,6 @@ class TimeoutTest(test.TestCase, parameterized.TestCase):
communication_hint=communication)
def testExecutionAfterTimeout(self, collective_op, device, communication):
if device == 'GPU':
self.skipTest('b/170980122')
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
timeout = 1.5