[XLA:GPU] Warn if threads in collective operation thunks appear stuck.

This would have helped the JAX team yesterday when debugging a stuck
all-reduce.

I have more simplifications I want to make to nccl_all_reduce_thunk, but
I'm keeping this change small.

PiperOrigin-RevId: 253621853
Justin Lebar 2019-06-17 11:25:20 -07:00 committed by TensorFlower Gardener
parent 700c5d8eb1
commit d9a37895d6
2 changed files with 113 additions and 34 deletions
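
The core of the change is a small templated helper, WaitAndLogIfStuck, that waits on a tensorflow::BlockingCounter with a 5-second timeout, logs an error if the timeout expires, and then keeps waiting and retracts the warning if the counter eventually reaches zero. Below is a minimal standalone sketch of that pattern; it substitutes a hand-rolled counter built on std::condition_variable for TF's BlockingCounter and std::cerr for LOG(ERROR)/VLOG, so SimpleBlockingCounter and everything else in it are illustrative names, not the XLA types.

#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>

// A minimal blocking counter with the same interface the diff relies on:
// DecrementCount(), Wait(), and WaitFor(timeout) returning true on success.
class SimpleBlockingCounter {
 public:
  explicit SimpleBlockingCounter(int n) : count_(n) {}

  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }

  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [&] { return count_ == 0; });
  }

  // Returns true if the count reached zero before the timeout expired.
  bool WaitFor(std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lock(mu_);
    return cv_.wait_for(lock, timeout, [&] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Same shape as the WaitAndLogIfStuck helper in the diff: try a bounded wait,
// warn if it expires, then fall back to an unbounded wait and retract the
// warning once the counter finally reaches zero.
template <typename DescFn>
void WaitAndLogIfStuck(SimpleBlockingCounter* counter, const DescFn& desc_fn) {
  const std::chrono::milliseconds timeout(5000);
  if (counter->WaitFor(timeout)) return;
  std::cerr << "This thread has been waiting for " << timeout.count()
            << "ms and may be stuck: " << desc_fn() << "\n";
  counter->Wait();
  std::cerr << "Thread is unstuck! Warning above was a false positive. "
               "Perhaps the timeout is too short: " << desc_fn() << "\n";
}

Falling back to an unbounded Wait() after the warning keeps behavior unchanged for slow-but-healthy runs: the log only adds information, it never turns a long wait into a failure.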


@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
#include <chrono> // NOLINT (required by TF interfaces)
#include <map>
#include <memory>
#include <vector>
@@ -40,6 +41,26 @@ namespace {
using tensorflow::BlockingCounter;
// This same function appears in nccl_all_reduce_thunk. I've copy/pasted it
// here primarily because I want the VLOGs to work.
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
const DescFn& desc_fn) {
VLOG(3) << "Begin: " << desc_fn();
const std::chrono::milliseconds timeout(5000);
bool ok = counter->WaitFor(timeout);
if (ok) {
VLOG(3) << "Finished: " << desc_fn();
return;
}
LOG(ERROR) << "This thread has been waiting for " << timeout.count()
<< "ms for and may be stuck: " << desc_fn();
counter->Wait();
LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
"Perhaps the timeout is too short: "
<< desc_fn();
}
// Key for looking up a Rendezvous object in our global map.
//
// Morally, the key is just a RunId. num_participants is in this struct only
@@ -48,6 +69,11 @@ struct RendezvousKey {
RunId run_id;
int64 num_participants;
string ToString() const {
return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
run_id.ToString(), num_participants);
}
template <typename H>
friend H AbslHashValue(H h, const RendezvousKey& k) {
return H::combine(std::move(h), k.run_id);
@@ -82,11 +108,11 @@ struct ParticipantData {
// Rendezvous objects can only be used once.
class Rendezvous {
public:
explicit Rendezvous(int64 num_participants)
: num_participants_(num_participants),
all_arrived_(num_participants),
explicit Rendezvous(const RendezvousKey& key)
: key_(key),
all_arrived_(key.num_participants),
returned_blocking_counter_(
std::make_shared<BlockingCounter>(num_participants)) {}
std::make_shared<BlockingCounter>(key.num_participants)) {}
// Runs the collective permute on the given thread.
//
@@ -98,7 +124,7 @@ class Rendezvous {
ParticipantData participant);
private:
const int64 num_participants_;
const RendezvousKey key_;
BlockingCounter all_arrived_;
// BlockingCounter returned by SubmitParticipant.
@@ -146,7 +172,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
if (primary) {
initialized_ = true;
returned_blocking_counter_ =
std::make_shared<BlockingCounter>(num_participants_);
std::make_shared<BlockingCounter>(key_.num_participants);
}
}
@@ -155,7 +181,13 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
// everyone, then we wouldn't be able to enqueue the copies at the correct
// point in their streams.
all_arrived_.DecrementCount();
all_arrived_.Wait();
WaitAndLogIfStuck(&all_arrived_, [&] {
return absl::StrFormat(
"participant for replica %d (stream %p, device %d) waiting for all "
"other participants to arrive: %s",
participant.replica_id, participant.stream,
participant.stream->parent()->device_ordinal(), key_.ToString());
});
// Schedule the copies between the devices. This is much easier to reason
// about if we schedule all of the copies from just one thread. The copies
@@ -185,7 +217,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
[](const RendezvousKey& key) {
return absl::make_unique<Rendezvous>(key.num_participants);
return absl::make_unique<Rendezvous>(key);
});
return m;
}
@@ -245,7 +277,13 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
// erase() is deceptively complex to implement correctly.
rendezvous.reset();
final_sync->DecrementCount();
final_sync->Wait();
WaitAndLogIfStuck(final_sync.get(), [&] {
return absl::StrFormat(
"participant for replica %d (stream %p, device ordinal %d) waiting for "
"all threads to drop their reference to the rendezvous: %s",
participant.replica_id, participant.stream,
participant.stream->parent()->device_ordinal(), key.ToString());
});
return Status::OK();
}
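
With this in place, every participant in collective_permute_thunk decrements the all_arrived_ counter and then calls WaitAndLogIfStuck instead of a bare Wait(), so a straggling replica produces a log message rather than a silent hang. The following usage demo, which assumes the SimpleBlockingCounter and WaitAndLogIfStuck sketch above is in the same file, simulates that with three hypothetical replicas, one of them delayed past the timeout:

#include <chrono>
#include <string>
#include <thread>
#include <vector>

int main() {
  // Three participants rendezvous on one counter. Replica 2 is deliberately
  // delayed past the 5s timeout, so replicas 0 and 1 print the "may be stuck"
  // warning and later the retraction once the straggler arrives.
  SimpleBlockingCounter all_arrived(3);
  std::vector<std::thread> threads;
  for (int replica = 0; replica < 3; ++replica) {
    threads.emplace_back([&all_arrived, replica] {
      if (replica == 2) {
        // Simulated straggler (e.g. a replica whose work is launched late).
        std::this_thread::sleep_for(std::chrono::seconds(7));
      }
      all_arrived.DecrementCount();
      WaitAndLogIfStuck(&all_arrived, [&] {
        return "replica " + std::to_string(replica) +
               " waiting for all participants to arrive";
      });
    });
  }
  for (auto& t : threads) t.join();
  return 0;
}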


@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#if GOOGLE_CUDA
#include <chrono> // NOLINT (required by TF interfaces)
#include <memory>
#include <string>
#include <utility>
@@ -114,6 +115,24 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line,
} \
} while (0)
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
const DescFn& desc_fn) {
VLOG(3) << "Begin: " << desc_fn();
const std::chrono::milliseconds timeout(5000);
bool ok = counter->WaitFor(timeout);
if (ok) {
VLOG(3) << "Finished: " << desc_fn();
return;
}
LOG(ERROR) << "This thread has been waiting for " << timeout.count()
<< "ms for and may be stuck: " << desc_fn();
counter->Wait();
LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
"Perhaps the timeout is too short: "
<< desc_fn();
}
// RAII class owning a ncclComm_t, ensuring it doesn't leak.
class NcclComm {
public:
@@ -390,7 +409,8 @@ RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
// Rendezvous objects can only be used once.
class Rendezvous {
public:
Rendezvous() = default;
explicit Rendezvous(const RendezvousKey& k)
: key_(k), all_participants_present_(k.participating_replicas.size()) {}
// Runs the all-reduce on the given thread. If successful, returns
// - a handle to the clique that was used, so that the caller may keep the
@@ -405,8 +425,10 @@ class Rendezvous {
private:
Status DoAllReduce(ParticipantData participant, ncclComm_t comm);
const RendezvousKey key_;
tensorflow::BlockingCounter all_participants_present_;
tensorflow::mutex mu_;
tensorflow::condition_variable all_participants_present_;
bool initialized_ GUARDED_BY(mu_) = false;
absl::optional<tensorflow::BlockingCounter> done_;
@@ -424,23 +446,14 @@ class Rendezvous {
// Rendezvous objects are one-time use, so they're removed from this map once
// we're through with them.
RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>();
static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
[](const RendezvousKey& k) { return absl::make_unique<Rendezvous>(k); });
return m;
}
StatusOr<std::pair<std::shared_ptr<NcclClique>,
std::shared_ptr<tensorflow::BlockingCounter>>>
Rendezvous::SubmitParticipant(ParticipantData participant) {
// We pull into our thread a) the communication handle and b) whether we're
// the "primary" thread for this rendezvous -- the "primary" thread has some
// additional responsibilities for setup/teardown.
ncclComm_t comm;
bool primary;
std::shared_ptr<NcclClique> clique;
// Releases the lock on the clique (held only by the primary thread).
Cleanup<std::function<void()>> clique_lock_releaser;
{
tensorflow::mutex_lock lock(mu_);
CHECK(!initialized_);
@@ -455,14 +468,29 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
participants_.back().ToString(), participant.ToString());
}
participants_.push_back(participant);
}
// Wait here for all participants to arrive.
while (participants_.size() < participant.num_participants()) {
all_participants_present_.wait(lock);
}
if (participants_.size() == participant.num_participants()) {
all_participants_present_.notify_all();
}
// Wait for all participants to arrive.
all_participants_present_.DecrementCount();
WaitAndLogIfStuck(&all_participants_present_, [&] {
return absl::StrFormat(
"participant for device ordinal %d, stream %p waiting for all "
"participants to be arrive at rendezvous %s",
participant.device_ordinal, participant.stream, key_.ToString());
});
// We pull into our thread a) the communication handle and b) whether we're
// the "primary" thread for this rendezvous -- the "primary" thread has some
// additional responsibilities for setup/teardown.
ncclComm_t comm;
bool primary;
std::shared_ptr<NcclClique> clique;
// Releases the lock on the clique (held only by the primary thread).
Cleanup<std::function<void()>> clique_lock_releaser;
{
tensorflow::mutex_lock lock(mu_);
// The first thread to get here has additional responsibilities, such as
// ensuring that there's a NCCL clique available for us to use.
@@ -512,10 +540,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
// The primary owns the lock on the NCCL clique. Hold it until all threads
// are done. (We'll release it when we return from this function.)
if (primary) {
VLOG(3)
<< "Primary waiting for all participants to complete all-reduce op.";
done_->Wait();
VLOG(3) << "All participants completed all-reduce op.";
WaitAndLogIfStuck(&*done_, [&] {
return absl::StrFormat(
"primary participant (device ordinal %d, stream %p) waiting for all "
"other participants to complete all-reduce %s",
participant.device_ordinal, participant.stream, key_.ToString());
});
}
VLOG(3) << "Returning status: " << all_reduce_status;
@@ -624,6 +654,7 @@ static StatusOr<std::vector<int64>> GetParticipatingReplicas(
}
Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(1) << "Starting NcclAllReduceThunk.";
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
@@ -642,6 +673,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
std::shared_ptr<Rendezvous> rendezvous =
GlobalRendezvousMap()[rendezvous_key];
VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
<< ", rendezvous: " << rendezvous.get()
<< ", participating replicas: "
<< absl::StrJoin(participating_replicas, ", ");
ParticipantData participant(rendezvous_key);
participant.element_count = element_count_;
participant.device_ordinal = device_ordinal;
@@ -682,7 +718,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
// erase() is deceptively complex to implement correctly.
rendezvous.reset();
blocking_counter->DecrementCount();
blocking_counter->Wait();
WaitAndLogIfStuck(blocking_counter.get(), [&] {
return absl::StrFormat(
"participant for device ordinal %d, stream %p waiting for "
"all threads to drop their reference to the rendezvous: %s",
device_ordinal, params.stream, rendezvous_key.ToString());
});
return Status::OK();
}
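
Both ExecuteOnStream call sites also end with the same teardown sequence: a thread first drops its shared_ptr to the Rendezvous, then decrements a final counter and waits, via WaitAndLogIfStuck, for every other thread to do the same, which is what makes erasing the map entry safe. Here is a sketch of that ordering, again assuming the SimpleBlockingCounter helper above; FakeRendezvous and FinishParticipant are hypothetical names, not the XLA ones.

#include <memory>
#include <string>

struct FakeRendezvous {};  // Stand-in for the per-RunId rendezvous object.

// Each thread must release its reference to the rendezvous *before* the
// final wait; only when every thread has done so is it safe for the last
// one out to erase the rendezvous from the global map.
void FinishParticipant(std::shared_ptr<FakeRendezvous> rendezvous,
                       std::shared_ptr<SimpleBlockingCounter> final_sync,
                       int device_ordinal) {
  rendezvous.reset();
  final_sync->DecrementCount();
  WaitAndLogIfStuck(final_sync.get(), [&] {
    return "device ordinal " + std::to_string(device_ordinal) +
           " waiting for all threads to drop their rendezvous reference";
  });
}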