[XLA:GPU] Warn if threads in collective operation thunks appear stuck.

This would have helped the JAX team yesterday when debugging a stuck
all-reduce.

I have more simplifications I want to make to nccl_all_reduce_thunk, but
keeping this change small.

PiperOrigin-RevId: 253621853
This commit is contained in:
Justin Lebar 2019-06-17 11:25:20 -07:00 committed by TensorFlower Gardener
parent 700c5d8eb1
commit d9a37895d6
2 changed files with 113 additions and 34 deletions

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h" #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
#include <chrono> // NOLINT (required by TF interfaces)
#include <map> #include <map>
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -40,6 +41,26 @@ namespace {
using tensorflow::BlockingCounter; using tensorflow::BlockingCounter;
// This same function appears in nccl_all_reduce_thunk. I've copy/pasted it
// here primarily because I want the VLOGs to work.
//
// Blocks until `counter` reaches zero.  If the wait exceeds a fixed timeout,
// logs an error describing this participant (via desc_fn()) as possibly
// stuck, then keeps waiting; if the counter later completes, logs that the
// earlier warning was a false positive.
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  // Fix: original message read "for ...ms for and may be stuck".
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
                "Perhaps the timeout is too short: "
             << desc_fn();
}
// Key for looking up a Rendezvous object in our global map. // Key for looking up a Rendezvous object in our global map.
// //
// Morally, the key is just a RunId. num_participants is in this struct only // Morally, the key is just a RunId. num_participants is in this struct only
@ -48,6 +69,11 @@ struct RendezvousKey {
RunId run_id; RunId run_id;
int64 num_participants; int64 num_participants;
// Human-readable rendering of this key, for logs and error messages,
// e.g. "RendezvousKey{run_id=..., num_participants=2}".
string ToString() const {
  const string run_id_str = run_id.ToString();
  return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
                         run_id_str, num_participants);
}
template <typename H> template <typename H>
friend H AbslHashValue(H h, const RendezvousKey& k) { friend H AbslHashValue(H h, const RendezvousKey& k) {
return H::combine(std::move(h), k.run_id); return H::combine(std::move(h), k.run_id);
@ -82,11 +108,11 @@ struct ParticipantData {
// Rendezvous objects can only be used once. // Rendezvous objects can only be used once.
class Rendezvous { class Rendezvous {
public: public:
explicit Rendezvous(int64 num_participants) explicit Rendezvous(const RendezvousKey& key)
: num_participants_(num_participants), : key_(key),
all_arrived_(num_participants), all_arrived_(key.num_participants),
returned_blocking_counter_( returned_blocking_counter_(
std::make_shared<BlockingCounter>(num_participants)) {} std::make_shared<BlockingCounter>(key.num_participants)) {}
// Runs the collective permute on the given thread. // Runs the collective permute on the given thread.
// //
@ -98,7 +124,7 @@ class Rendezvous {
ParticipantData participant); ParticipantData participant);
private: private:
const int64 num_participants_; const RendezvousKey key_;
BlockingCounter all_arrived_; BlockingCounter all_arrived_;
// BlockingCounter returned by SubmitParticipant. // BlockingCounter returned by SubmitParticipant.
@ -146,7 +172,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
if (primary) { if (primary) {
initialized_ = true; initialized_ = true;
returned_blocking_counter_ = returned_blocking_counter_ =
std::make_shared<BlockingCounter>(num_participants_); std::make_shared<BlockingCounter>(key_.num_participants);
} }
} }
@ -155,7 +181,13 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
// everyone, then we wouldn't be able to enqueue the copies at the correct // everyone, then we wouldn't be able to enqueue the copies at the correct
// point in their streams. // point in their streams.
all_arrived_.DecrementCount(); all_arrived_.DecrementCount();
all_arrived_.Wait(); WaitAndLogIfStuck(&all_arrived_, [&] {
return absl::StrFormat(
"participant for replica %d (stream %p, device %d) waiting for all "
"other participants to arrive: %s",
participant.replica_id, participant.stream,
participant.stream->parent()->device_ordinal(), key_.ToString());
});
// Schedule the copies between the devices. This is much easier to reason // Schedule the copies between the devices. This is much easier to reason
// about if we schedule all of the copies from just one thread. The copies // about if we schedule all of the copies from just one thread. The copies
@ -185,7 +217,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() { RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>( static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
[](const RendezvousKey& key) { [](const RendezvousKey& key) {
return absl::make_unique<Rendezvous>(key.num_participants); return absl::make_unique<Rendezvous>(key);
}); });
return m; return m;
} }
@ -245,7 +277,13 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
// erase() is deceptively complex to implement correctly. // erase() is deceptively complex to implement correctly.
rendezvous.reset(); rendezvous.reset();
final_sync->DecrementCount(); final_sync->DecrementCount();
final_sync->Wait(); WaitAndLogIfStuck(final_sync.get(), [&] {
return absl::StrFormat(
"participant for replica %d (stream %p, device ordinal %d) waiting for "
"all threads to drop their reference to the rendezvous: %s",
participant.replica_id, participant.stream,
participant.stream->parent()->device_ordinal(), key.ToString());
});
return Status::OK(); return Status::OK();
} }

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA
#include <chrono> // NOLINT (required by TF interfaces)
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
@ -114,6 +115,24 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line,
} \ } \
} while (0) } while (0)
// Blocks until `counter` reaches zero.  If the wait exceeds a fixed timeout,
// logs an error describing this participant (via desc_fn()) as possibly
// stuck, then keeps waiting; if the counter later completes, logs that the
// earlier warning was a false positive.
//
// NOTE(review): this is a copy of the helper in collective_permute_thunk.cc;
// kept duplicated so the VLOGs attribute to this file.
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  // Fix: original message read "for ...ms for and may be stuck".
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
                "Perhaps the timeout is too short: "
             << desc_fn();
}
// RAII class owning a ncclComm_t, ensuring it doesn't leak. // RAII class owning a ncclComm_t, ensuring it doesn't leak.
class NcclComm { class NcclComm {
public: public:
@ -390,7 +409,8 @@ RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
// Rendezvous objects can only be used once. // Rendezvous objects can only be used once.
class Rendezvous { class Rendezvous {
public: public:
Rendezvous() = default; explicit Rendezvous(const RendezvousKey& k)
: key_(k), all_participants_present_(k.participating_replicas.size()) {}
// Runs the all-reduce on the given thread. If successful, returns // Runs the all-reduce on the given thread. If successful, returns
// - a handle to the clique that was used, so that the caller may keep the // - a handle to the clique that was used, so that the caller may keep the
@ -405,8 +425,10 @@ class Rendezvous {
private: private:
Status DoAllReduce(ParticipantData participant, ncclComm_t comm); Status DoAllReduce(ParticipantData participant, ncclComm_t comm);
const RendezvousKey key_;
tensorflow::BlockingCounter all_participants_present_;
tensorflow::mutex mu_; tensorflow::mutex mu_;
tensorflow::condition_variable all_participants_present_;
bool initialized_ GUARDED_BY(mu_) = false; bool initialized_ GUARDED_BY(mu_) = false;
absl::optional<tensorflow::BlockingCounter> done_; absl::optional<tensorflow::BlockingCounter> done_;
@ -424,23 +446,14 @@ class Rendezvous {
// Rendezvous objects are one-time use, so they're removed from this map once // Rendezvous objects are one-time use, so they're removed from this map once
// we're through with them. // we're through with them.
RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() { RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(); static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
[](const RendezvousKey& k) { return absl::make_unique<Rendezvous>(k); });
return m; return m;
} }
StatusOr<std::pair<std::shared_ptr<NcclClique>, StatusOr<std::pair<std::shared_ptr<NcclClique>,
std::shared_ptr<tensorflow::BlockingCounter>>> std::shared_ptr<tensorflow::BlockingCounter>>>
Rendezvous::SubmitParticipant(ParticipantData participant) { Rendezvous::SubmitParticipant(ParticipantData participant) {
// We pull into our thread a) the communication handle and b) whether we're
// the "primary" thread for this rendezvous -- the "primary" thread has some
// additional responsibilities for setup/teardown.
ncclComm_t comm;
bool primary;
std::shared_ptr<NcclClique> clique;
// Releases the lock on the clique (held only by the primary thread).
Cleanup<std::function<void()>> clique_lock_releaser;
{ {
tensorflow::mutex_lock lock(mu_); tensorflow::mutex_lock lock(mu_);
CHECK(!initialized_); CHECK(!initialized_);
@ -455,14 +468,29 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
participants_.back().ToString(), participant.ToString()); participants_.back().ToString(), participant.ToString());
} }
participants_.push_back(participant); participants_.push_back(participant);
}
// Wait here for all participants to arrive. // Wait for all participants to arrive.
while (participants_.size() < participant.num_participants()) { all_participants_present_.DecrementCount();
all_participants_present_.wait(lock); WaitAndLogIfStuck(&all_participants_present_, [&] {
} return absl::StrFormat(
if (participants_.size() == participant.num_participants()) { "participant for device ordinal %d, stream %p waiting for all "
all_participants_present_.notify_all(); "participants to be arrive at rendezvous %s",
} participant.device_ordinal, participant.stream, key_.ToString());
});
// We pull into our thread a) the communication handle and b) whether we're
// the "primary" thread for this rendezvous -- the "primary" thread has some
// additional responsibilities for setup/teardown.
ncclComm_t comm;
bool primary;
std::shared_ptr<NcclClique> clique;
// Releases the lock on the clique (held only by the primary thread).
Cleanup<std::function<void()>> clique_lock_releaser;
{
tensorflow::mutex_lock lock(mu_);
// The first thread to get here has additional responsibilities, such as // The first thread to get here has additional responsibilities, such as
// ensuring that there's a NCCL clique available for us to use. // ensuring that there's a NCCL clique available for us to use.
@ -512,10 +540,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
// The primary owns the lock on the NCCL clique. Hold it until all threads // The primary owns the lock on the NCCL clique. Hold it until all threads
// are done. (We'll release it when we return from this function.) // are done. (We'll release it when we return from this function.)
if (primary) { if (primary) {
VLOG(3) WaitAndLogIfStuck(&*done_, [&] {
<< "Primary waiting for all participants to complete all-reduce op."; return absl::StrFormat(
done_->Wait(); "primary participant (device ordinal %d, stream %p) waiting for all "
VLOG(3) << "All participants completed all-reduce op."; "other participants to complete all-reduce %s",
participant.device_ordinal, participant.stream, key_.ToString());
});
} }
VLOG(3) << "Returning status: " << all_reduce_status; VLOG(3) << "Returning status: " << all_reduce_status;
@ -624,6 +654,7 @@ static StatusOr<std::vector<int64>> GetParticipatingReplicas(
} }
Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(1) << "Starting NcclAllReduceThunk.";
auto op_profiler = auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
@ -642,6 +673,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
std::shared_ptr<Rendezvous> rendezvous = std::shared_ptr<Rendezvous> rendezvous =
GlobalRendezvousMap()[rendezvous_key]; GlobalRendezvousMap()[rendezvous_key];
VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
<< ", rendezvous: " << rendezvous.get()
<< ", participating replicas: "
<< absl::StrJoin(participating_replicas, ", ");
ParticipantData participant(rendezvous_key); ParticipantData participant(rendezvous_key);
participant.element_count = element_count_; participant.element_count = element_count_;
participant.device_ordinal = device_ordinal; participant.device_ordinal = device_ordinal;
@ -682,7 +718,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
// erase() is deceptively complex to implement correctly. // erase() is deceptively complex to implement correctly.
rendezvous.reset(); rendezvous.reset();
blocking_counter->DecrementCount(); blocking_counter->DecrementCount();
blocking_counter->Wait(); WaitAndLogIfStuck(blocking_counter.get(), [&] {
return absl::StrFormat(
"participant for device ordinal %d, stream %p waiting for "
"all threads to drop their reference to the rendezvous: %s",
device_ordinal, params.stream, rendezvous_key.ToString());
});
return Status::OK(); return Status::OK();
} }