Fix a use-after-free when a collective times out

Since NcclCommunicator is owned by CollectiveExecutorMgr, and CollectiveExecutorMgr may be destructed while the timeout callback is still executing, there could be a use-after-free.

Given that the timeout is a debugging-only feature, it doesn't seem worth the cost of changing all the ownership to shared_ptr. Instead, the timeout callback now invokes done() only after StartAbort(), which ensures that all the resources are still available while the abort runs, because there is still one pending kernel.
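
To make that interplay concrete, here is a minimal standalone sketch of the pattern the diff below adopts (plain C++ with std::thread, and a string standing in for Status; RunWithTimeout, the start_abort callback, and all other names here are illustrative and are not TensorFlow code): the normal completion path and the timeout watchdog race on a shared atomic flag, whichever wins calls done() exactly once, and the timeout path calls done() only after the abort has finished, so the owner that tears things down on done() cannot free anything the abort still needs.

// Illustrative sketch only; not TensorFlow code.
#include <atomic>
#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>

using StatusCallback = std::function<void(const std::string&)>;

// Wraps `done` so that it fires exactly once, either from the normal
// completion path or from the timeout watchdog, whichever comes first.
void RunWithTimeout(StatusCallback done, std::chrono::milliseconds timeout,
                    std::function<void()> start_abort) {
  auto is_callback_called = std::make_shared<std::atomic<bool>>(false);

  // Handed to the normal completion path; becomes a no-op if the timeout won.
  StatusCallback done_safe = [done, is_callback_called](const std::string& s) {
    if (!is_callback_called->exchange(true)) done(s);
  };

  // Timeout watchdog: abort first, then report the deadline error. Because
  // done() has not fired yet, whoever owns the aborted resources has not torn
  // them down, so start_abort() runs against live state.
  std::thread([done, is_callback_called, timeout, start_abort] {
    std::this_thread::sleep_for(timeout);
    if (!is_callback_called->exchange(true)) {
      start_abort();
      done("DEADLINE_EXCEEDED");
    }
  }).detach();

  // Simulated slow collective that finishes after the deadline; its call to
  // done_safe is then a no-op.
  std::thread([done_safe] {
    std::this_thread::sleep_for(std::chrono::milliseconds(150));
    done_safe("OK");
  }).detach();
}

int main() {
  RunWithTimeout([](const std::string& s) { std::cout << "done: " << s << "\n"; },
                 std::chrono::milliseconds(50),
                 [] { std::cout << "StartAbort\n"; });
  std::this_thread::sleep_for(std::chrono::milliseconds(300));
  return 0;
}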

The change in NcclManager is mostly to silence ASan.
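
The shape of that NcclManager change, as seen in the hunks below, is: swap the communicator list out from under the mutex, run the (possibly slow) aborts without holding the lock, and clear each native handle once it has been aborted so nothing later touches a dead communicator. A generic sketch of that shape follows (FakeComm/AbortComm stand in for ncclComm_t/ncclCommAbort, and Manager is not the real NcclManager):

// Generic sketch of the swap-then-abort pattern; not the real NcclManager.
#include <mutex>
#include <vector>

struct FakeComm { bool aborted = false; };
void AbortComm(FakeComm* comm) { comm->aborted = true; }

struct Member {
  FakeComm* comm = nullptr;  // non-const so it can be cleared after the abort
};

class Manager {
 public:
  void AddMember(FakeComm* comm) {
    std::lock_guard<std::mutex> l(mu_);
    members_.push_back(Member{comm});
  }

  void StartAbort() {
    std::vector<Member> members;
    {
      std::lock_guard<std::mutex> l(mu_);
      members.swap(members_);  // take ownership; members_ is no longer used
    }
    // Abort outside the lock; clearing the handle afterwards ensures no later
    // code (e.g. a destructor) dereferences an already-aborted communicator.
    for (Member& m : members) {
      AbortComm(m.comm);
      m.comm = nullptr;
    }
  }

 private:
  std::mutex mu_;
  std::vector<Member> members_;  // analogous to communicators_ TF_GUARDED_BY(mu_)
};

int main() {
  FakeComm c;
  Manager mgr;
  mgr.AddMember(&c);
  mgr.StartAbort();
  return 0;
}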

PiperOrigin-RevId: 337928891
Change-Id: I855979ade4e0309ac2c4e2cd9c4d844588ee6b4f
Author: Ran Chen, 2020-10-19 14:06:21 -07:00 (committed by TensorFlower Gardener)
Parent: c031923133
Commit: af93956653
4 changed files with 38 additions and 43 deletions


@@ -227,31 +227,27 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                           const CollectiveParams& col_params,
                                           const string& exec_key,
                                           StatusCallback done) {
+  // See CompleteParamsAsync() how done() and the timeout callback interacts.
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
-  StatusCallback done_safe = [done = std::move(done),
-                              is_callback_called](const Status& s) {
+  auto done_safe = [done, is_callback_called](const Status& s) {
     bool called = is_callback_called->exchange(true);
-    CHECK(!called) << "done callback is called twice in "  // Crash OK
-                      "BaseCollectiveExecutor::ExecuteAsync. Please file a "
-                      "issue on https://github.com/tensorflow/tensorflow.";
-    done(s);
+    if (!called) {
+      done(s);
+    }
   };
   auto timeout_microseconds = static_cast<int64>(
       col_params.instance.impl_details.timeout_seconds * 1'000'000);
   if (timeout_microseconds > 0) {
-    // Ensure this BaseCollectiveExecutor is alive when StartAbort() is called.
-    Ref();
     // TODO(xldrx): Share the timeout watchdog thread among collectives.
     SchedNonBlockingClosureAfter(
-        timeout_microseconds, [this, is_callback_called] {
-          if (!is_callback_called->load()) {
+        timeout_microseconds, [this, is_callback_called, done] {
+          bool called = is_callback_called->exchange(true);
+          if (!called) {
             Status status(error::DEADLINE_EXCEEDED,
                           "Collective has timed out during execution.");
             StartAbort(status);
+            done(status);
           }
-          Unref();
         });
   }
@@ -306,30 +302,34 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
     const DeviceAttributes& device, CollectiveParams* cp,
     CancellationManager* cancel_mgr, StatusCallback done) {
   cp->group.gpu_ring_order = *gpu_ring_order_;
+  // We need to make sure that when the timeout callback executes,
+  // CollectiveExecutor and CollectiveExecutorMgr are both alive. After done()
+  // is called, CollectiveExecutorMgr may be destructed and we don't have a way
+  // to keep it without making the ownerships more complicated. Therefore if the
+  // timeout callback executes, done_safe will become a no-op and the timeout
+  // callback is responsible for invoking done() at the end.
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
-  auto done_safe = [is_callback_called,
-                    done = std::move(done)](const Status& s) {
+  auto done_safe = [done, is_callback_called](const Status& s) {
     bool called = is_callback_called->exchange(true);
-    CHECK(!called) << "done callback is called twice in "  // Crash OK
-                      "BaseCollectiveExecutor::ExecuteAsync. Please file a "
-                      "issue on https://github.com/tensorflow/tensorflow.";
-    done(s);
+    if (!called) {
+      done(s);
+    }
   };
   auto timeout_microseconds =
       static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
   if (timeout_microseconds > 0) {
-    // Ensure this BaseCollectiveExecutor is alive when StartAbort() is called.
-    Ref();
     // TODO(xldrx): Share the timeout watchdog thread among collectives.
-    SchedNonBlockingClosureAfter(timeout_microseconds, [this,
-                                                        is_callback_called]() {
-      if (!is_callback_called->load()) {
-        Status status(error::DEADLINE_EXCEEDED,
-                      "Collective has timed out waiting for other workers.");
-        StartAbort(status);
-      }
-      Unref();
-    });
+    SchedNonBlockingClosureAfter(
+        timeout_microseconds, [this, is_callback_called, done]() {
+          bool called = is_callback_called->exchange(true);
+          if (!called) {
+            Status status(
+                error::DEADLINE_EXCEEDED,
+                "Collective has timed out waiting for other workers.");
+            StartAbort(status);
+            done(status);
+          }
+        });
   }
   cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
                                                 done_safe);


@@ -115,7 +115,7 @@ struct NcclManager::Communicator {
       : num_devices(members.size()), members(std::move(members)), key(key) {}

   const int num_devices;
-  const std::vector<CommunicatorMember> members;
+  std::vector<CommunicatorMember> members;
   const string key;
 };

@@ -851,8 +851,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
 void NcclManager::StartAbort(const Status& s) {
   VLOG(1) << "NcclManager StartAbort";
   absl::flat_hash_map<string, Collective*> collectives;
-  // After status_ is set to a non-OK one, there should be no further
-  // modifications to collectives_.
+  std::vector<std::unique_ptr<Communicator>> communicators;
   {
     mutex_lock l(mu_);
     if (!status_.ok()) {

@@ -863,6 +862,7 @@ void NcclManager::StartAbort(const Status& s) {
     }
     status_ = s;
     collectives.swap(collectives_);
+    communicators.swap(communicators_);
   }
   // collectives_ contains pending launches that haven't been dispatched to
   // kernel launch threads, so we can simply invoke the done callbacks of them.

@@ -879,14 +879,15 @@ void NcclManager::StartAbort(const Status& s) {
   // there's only one active NcclManager at a time.
   UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
   int num_comms = 0;
-  for (std::unique_ptr<Communicator>& communicator : communicators_) {
+  for (std::unique_ptr<Communicator>& communicator : communicators) {
     num_comms += communicator->members.size();
   }
   BlockingCounter pending(num_comms);
-  for (std::unique_ptr<Communicator>& communicator : communicators_) {
-    for (const CommunicatorMember& member : communicator->members) {
+  for (std::unique_ptr<Communicator>& communicator : communicators) {
+    for (CommunicatorMember& member : communicator->members) {
       queue.Schedule([&member, &pending]() {
         ncclCommAbort(member.nccl_comm);
+        member.nccl_comm = nullptr;
         pending.DecrementCount();
       });
     }


@@ -248,7 +248,7 @@ class NcclManager {
   absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
       device_to_comm_streams_ TF_GUARDED_BY(mu_);

-  std::vector<std::unique_ptr<Communicator>> communicators_;
+  std::vector<std::unique_ptr<Communicator>> communicators_ TF_GUARDED_BY(mu_);

   Status status_ TF_GUARDED_BY(mu_);


@@ -491,8 +491,6 @@ class TimeoutTest(test.TestCase, parameterized.TestCase):
     super().setUp()

   def testTimeout(self, collective_op, device, communication):
-    if device == 'GPU':
-      self.skipTest('b/170980122')
     timeout = 1.5

     @def_function.function

@@ -527,8 +525,6 @@ class TimeoutTest(test.TestCase, parameterized.TestCase):

   def testParamResolutionAfterTimeout(self, collective_op, device,
                                       communication):
-    if device == 'GPU':
-      self.skipTest('b/170980122')
     dev0 = '/device:%s:0' % device
     dev1 = '/device:%s:1' % device
     timeout = 1.5

@@ -564,8 +560,6 @@ class TimeoutTest(test.TestCase, parameterized.TestCase):
         communication_hint=communication)

   def testExecutionAfterTimeout(self, collective_op, device, communication):
-    if device == 'GPU':
-      self.skipTest('b/170980122')
     dev0 = '/device:%s:0' % device
     dev1 = '/device:%s:1' % device
     timeout = 1.5