[XLA:GPU] Warn if threads in collective operation thunks appear stuck.
This would have helped the JAX team yesterday when debugging a stuck
all-reduce. I have more simplifications I want to make to
nccl_all_reduce_thunk, but keeping this change small.

PiperOrigin-RevId: 253621853
commit d9a37895d6 (parent 700c5d8eb1)
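The pattern this commit adds is small enough to show in isolation: decrement a blocking counter, wait with a timeout, log a warning if the timeout fires, then fall back to an unbounded wait so a false positive can retract itself. Below is a minimal, self-contained C++ sketch of that pattern. SimpleBlockingCounter is a hand-rolled stand-in for tensorflow::BlockingCounter (whose DecrementCount/Wait/WaitFor interface the diffs below rely on); the main() driver and all names here are illustrative, not XLA code.

// Illustrative sketch only -- SimpleBlockingCounter is a hypothetical
// stand-in for tensorflow::BlockingCounter, not the real class.
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

class SimpleBlockingCounter {
 public:
  explicit SimpleBlockingCounter(int count) : count_(count) {}

  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }

  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

  // Returns true iff the count hit zero before the timeout expired.
  bool WaitFor(std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lock(mu_);
    return cv_.wait_for(lock, timeout, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Same shape as the WaitAndLogIfStuck added by this commit: a bounded wait,
// a warning on timeout, then an unbounded wait that retracts the warning if
// it ever returns.
template <typename DescFn>
void WaitAndLogIfStuck(SimpleBlockingCounter* counter, const DescFn& desc_fn) {
  const std::chrono::milliseconds timeout(5000);
  if (counter->WaitFor(timeout)) return;
  std::cerr << "Thread has been waiting " << timeout.count()
            << "ms and may be stuck: " << desc_fn() << "\n";
  counter->Wait();
  std::cerr << "Thread is unstuck; warning above was a false positive: "
            << desc_fn() << "\n";
}

int main() {
  constexpr int kNumThreads = 4;
  SimpleBlockingCounter rendezvous(kNumThreads);
  std::vector<std::thread> threads;
  for (int i = 0; i < kNumThreads; ++i) {
    threads.emplace_back([&rendezvous, i] {
      // Delay one thread past the timeout to provoke the "stuck" warning
      // on the others, followed by the false-positive retraction.
      if (i == 0) std::this_thread::sleep_for(std::chrono::milliseconds(5500));
      rendezvous.DecrementCount();
      WaitAndLogIfStuck(&rendezvous,
                        [i] { return "participant " + std::to_string(i); });
    });
  }
  for (auto& t : threads) t.join();
  return 0;
}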
--- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
 
+#include <chrono>  // NOLINT (required by TF interfaces)
 #include <map>
 #include <memory>
 #include <vector>
@@ -40,6 +41,26 @@ namespace {
 
 using tensorflow::BlockingCounter;
 
+// This same function appears in nccl_all_reduce_thunk. I've copy/pasted it
+// here primarily because I want the VLOGs to work.
+template <typename DescFn>
+void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
+                       const DescFn& desc_fn) {
+  VLOG(3) << "Begin: " << desc_fn();
+  const std::chrono::milliseconds timeout(5000);
+  bool ok = counter->WaitFor(timeout);
+  if (ok) {
+    VLOG(3) << "Finished: " << desc_fn();
+    return;
+  }
+  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
+             << "ms and may be stuck: " << desc_fn();
+  counter->Wait();
+  LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
+                "Perhaps the timeout is too short: "
+             << desc_fn();
+}
+
 // Key for looking up a Rendezvous object in our global map.
 //
 // Morally, the key is just a RunId. num_participants is in this struct only
@@ -48,6 +69,11 @@ struct RendezvousKey {
   RunId run_id;
   int64 num_participants;
 
+  string ToString() const {
+    return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
+                           run_id.ToString(), num_participants);
+  }
+
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
     return H::combine(std::move(h), k.run_id);
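For context on the hashing hook in the hunk just above: AbslHashValue is Abseil's extension point for making a user-defined type hashable, and RendezvousKey deliberately hashes only run_id. Here is a hypothetical standalone example of the same idiom; the Key struct and its fields are invented for illustration, and equality mirroring the "morally, the key is just a RunId" comment is my assumption about the surrounding code.

#include <cstdint>
#include <iostream>
#include <string>
#include "absl/container/flat_hash_map.h"

struct Key {
  int64_t run_id;
  int64_t num_participants;  // carried along, but not part of identity

  // Equality and hashing consider run_id only, mirroring RendezvousKey.
  friend bool operator==(const Key& a, const Key& b) {
    return a.run_id == b.run_id;
  }
  template <typename H>
  friend H AbslHashValue(H h, const Key& k) {
    return H::combine(std::move(h), k.run_id);
  }
};

int main() {
  absl::flat_hash_map<Key, std::string> m;
  m[Key{42, 8}] = "rendezvous for run 42";
  // Lookup succeeds even with a different num_participants, because only
  // run_id participates in hashing and equality.
  std::cout << m[Key{42, 4}] << "\n";
  return 0;
}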
@@ -82,11 +108,11 @@ struct ParticipantData {
 
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(int64 num_participants)
-      : num_participants_(num_participants),
-        all_arrived_(num_participants),
+  explicit Rendezvous(const RendezvousKey& key)
+      : key_(key),
+        all_arrived_(key.num_participants),
         returned_blocking_counter_(
-            std::make_shared<BlockingCounter>(num_participants)) {}
+            std::make_shared<BlockingCounter>(key.num_participants)) {}
 
   // Runs the collective permute on the given thread.
   //
@@ -98,7 +124,7 @@ class Rendezvous {
       ParticipantData participant);
 
  private:
-  const int64 num_participants_;
+  const RendezvousKey key_;
   BlockingCounter all_arrived_;
 
   // BlockingCounter returned by SubmitParticipant.
@@ -146,7 +172,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
     if (primary) {
       initialized_ = true;
       returned_blocking_counter_ =
-          std::make_shared<BlockingCounter>(num_participants_);
+          std::make_shared<BlockingCounter>(key_.num_participants);
     }
   }
 
@@ -155,7 +181,13 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
   // everyone, then we wouldn't be able to enqueue the copies at the correct
   // point in their streams.
   all_arrived_.DecrementCount();
-  all_arrived_.Wait();
+  WaitAndLogIfStuck(&all_arrived_, [&] {
+    return absl::StrFormat(
+        "participant for replica %d (stream %p, device %d) waiting for all "
+        "other participants to arrive: %s",
+        participant.replica_id, participant.stream,
+        participant.stream->parent()->device_ordinal(), key_.ToString());
+  });
 
   // Schedule the copies between the devices. This is much easier to reason
   // about if we schedule all of the copies from just one thread. The copies
@@ -185,7 +217,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
 RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
   static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
       [](const RendezvousKey& key) {
-        return absl::make_unique<Rendezvous>(key.num_participants);
+        return absl::make_unique<Rendezvous>(key);
       });
   return m;
 }
@@ -245,7 +277,13 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
   // erase() is deceptively complex to implement correctly.
   rendezvous.reset();
   final_sync->DecrementCount();
-  final_sync->Wait();
+  WaitAndLogIfStuck(final_sync.get(), [&] {
+    return absl::StrFormat(
+        "participant for replica %d (stream %p, device ordinal %d) waiting for "
+        "all threads to drop their reference to the rendezvous: %s",
+        participant.replica_id, participant.stream,
+        participant.stream->parent()->device_ordinal(), key.ToString());
+  });
   return Status::OK();
 }
 
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
 
 #if GOOGLE_CUDA
+#include <chrono>  // NOLINT (required by TF interfaces)
 #include <memory>
 #include <string>
 #include <utility>
@@ -114,6 +115,24 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line,
   }                               \
 } while (0)
 
+template <typename DescFn>
+void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
+                       const DescFn& desc_fn) {
+  VLOG(3) << "Begin: " << desc_fn();
+  const std::chrono::milliseconds timeout(5000);
+  bool ok = counter->WaitFor(timeout);
+  if (ok) {
+    VLOG(3) << "Finished: " << desc_fn();
+    return;
+  }
+  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
+             << "ms and may be stuck: " << desc_fn();
+  counter->Wait();
+  LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
+                "Perhaps the timeout is too short: "
+             << desc_fn();
+}
+
 // RAII class owning a ncclComm_t, ensuring it doesn't leak.
 class NcclComm {
  public:
@@ -390,7 +409,8 @@ RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  Rendezvous() = default;
+  explicit Rendezvous(const RendezvousKey& k)
+      : key_(k), all_participants_present_(k.participating_replicas.size()) {}
 
   // Runs the all-reduce on the given thread. If successful, returns
   //  - a handle to the clique that was used, so that the caller may keep the
@@ -405,8 +425,10 @@ class Rendezvous {
  private:
   Status DoAllReduce(ParticipantData participant, ncclComm_t comm);
 
+  const RendezvousKey key_;
+  tensorflow::BlockingCounter all_participants_present_;
+
   tensorflow::mutex mu_;
-  tensorflow::condition_variable all_participants_present_;
 
   bool initialized_ GUARDED_BY(mu_) = false;
   absl::optional<tensorflow::BlockingCounter> done_;
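A note on the member change in the hunk above: the old condition_variable pattern waited under the mutex with an untimed wait(lock), which gives WaitAndLogIfStuck nothing to hook into. BlockingCounter bundles the count, the completion predicate, and a timed WaitFor() into one object, so each participant can signal arrival with DecrementCount() and then do the bounded wait outside the lock, as the rewritten SubmitParticipant below shows.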
@@ -424,23 +446,14 @@ class Rendezvous {
 // Rendezvous objects are one-time use, so they're removed from this map once
 // we're through with them.
 RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
-  static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>();
+  static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
+      [](const RendezvousKey& k) { return absl::make_unique<Rendezvous>(k); });
   return m;
 }
 
 StatusOr<std::pair<std::shared_ptr<NcclClique>,
                    std::shared_ptr<tensorflow::BlockingCounter>>>
 Rendezvous::SubmitParticipant(ParticipantData participant) {
-  // We pull into our thread a) the communication handle and b) whether we're
-  // the "primary" thread for this rendezvous -- the "primary" thread has some
-  // additional responsibilities for setup/teardown.
-  ncclComm_t comm;
-  bool primary;
-  std::shared_ptr<NcclClique> clique;
-
-  // Releases the lock on the clique (held only by the primary thread).
-  Cleanup<std::function<void()>> clique_lock_releaser;
-
   {
     tensorflow::mutex_lock lock(mu_);
     CHECK(!initialized_);
@@ -455,14 +468,29 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
           participants_.back().ToString(), participant.ToString());
     }
     participants_.push_back(participant);
+  }
 
-    // Wait here for all participants to arrive.
-    while (participants_.size() < participant.num_participants()) {
-      all_participants_present_.wait(lock);
-    }
-    if (participants_.size() == participant.num_participants()) {
-      all_participants_present_.notify_all();
-    }
+  // Wait for all participants to arrive.
+  all_participants_present_.DecrementCount();
+  WaitAndLogIfStuck(&all_participants_present_, [&] {
+    return absl::StrFormat(
+        "participant for device ordinal %d, stream %p waiting for all "
+        "participants to arrive at rendezvous %s",
+        participant.device_ordinal, participant.stream, key_.ToString());
+  });
+
+  // We pull into our thread a) the communication handle and b) whether we're
+  // the "primary" thread for this rendezvous -- the "primary" thread has some
+  // additional responsibilities for setup/teardown.
+  ncclComm_t comm;
+  bool primary;
+  std::shared_ptr<NcclClique> clique;
+
+  // Releases the lock on the clique (held only by the primary thread).
+  Cleanup<std::function<void()>> clique_lock_releaser;
+
+  {
+    tensorflow::mutex_lock lock(mu_);
 
     // The first thread to get here has additional responsibilities, such as
     // ensuring that there's a NCCL clique available for us to use.
@@ -512,10 +540,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   // The primary owns the lock on the NCCL clique. Hold it until all threads
   // are done. (We'll release it when we return from this function.)
   if (primary) {
-    VLOG(3)
-        << "Primary waiting for all participants to complete all-reduce op.";
-    done_->Wait();
-    VLOG(3) << "All participants completed all-reduce op.";
+    WaitAndLogIfStuck(&*done_, [&] {
+      return absl::StrFormat(
+          "primary participant (device ordinal %d, stream %p) waiting for all "
+          "other participants to complete all-reduce %s",
+          participant.device_ordinal, participant.stream, key_.ToString());
+    });
   }
 
   VLOG(3) << "Returning status: " << all_reduce_status;
@@ -624,6 +654,7 @@ static StatusOr<std::vector<int64>> GetParticipatingReplicas(
 }
 
 Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
+  VLOG(1) << "Starting NcclAllReduceThunk.";
   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
 
@@ -642,6 +673,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   std::shared_ptr<Rendezvous> rendezvous =
       GlobalRendezvousMap()[rendezvous_key];
 
+  VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
+          << ", rendezvous: " << rendezvous.get()
+          << ", participating replicas: "
+          << absl::StrJoin(participating_replicas, ", ");
+
   ParticipantData participant(rendezvous_key);
   participant.element_count = element_count_;
   participant.device_ordinal = device_ordinal;
@@ -682,7 +718,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   // erase() is deceptively complex to implement correctly.
   rendezvous.reset();
   blocking_counter->DecrementCount();
-  blocking_counter->Wait();
+  WaitAndLogIfStuck(blocking_counter.get(), [&] {
+    return absl::StrFormat(
+        "participant for device ordinal %d, stream %p waiting for "
+        "all threads to drop their reference to the rendezvous: %s",
+        device_ordinal, params.stream, rendezvous_key.ToString());
+  });
 
   return Status::OK();
 }
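One reading of the teardown sequence in this last hunk (and its collective_permute twin above): each thread drops its shared_ptr with rendezvous.reset(), and only once the final counter releases everyone is the Rendezvous guaranteed destroyed and its entry gone from the RefcountingHashMap, which is why the comment says a manual erase() would be deceptively complex. Applying WaitAndLogIfStuck to that final counter means a thread that never drops its reference now shows up in the error log instead of hanging silently.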