diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index e853b622676..4fb9b06d02f 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h" +#include <chrono>  // NOLINT (required by TF interfaces) #include #include #include @@ -40,6 +41,26 @@ namespace { using tensorflow::BlockingCounter; +// This same function appears in nccl_all_reduce_thunk. I've copy/pasted it +// here primarily because I want the VLOGs to work. +template <typename DescFn> +void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter, + const DescFn& desc_fn) { + VLOG(3) << "Begin: " << desc_fn(); + const std::chrono::milliseconds timeout(5000); + bool ok = counter->WaitFor(timeout); + if (ok) { + VLOG(3) << "Finished: " << desc_fn(); + return; + } + LOG(ERROR) << "This thread has been waiting for " << timeout.count() + << "ms and may be stuck: " << desc_fn(); + counter->Wait(); + LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. " + "Perhaps the timeout is too short: " + << desc_fn(); +} + // Key for looking up a Rendezvous object in our global map. // // Morally, the key is just a RunId. num_participants is in this struct only @@ -48,6 +69,11 @@ struct RendezvousKey { RunId run_id; int64 num_participants; + string ToString() const { + return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}", + run_id.ToString(), num_participants); + } + template <typename H> friend H AbslHashValue(H h, const RendezvousKey& k) { return H::combine(std::move(h), k.run_id); @@ -82,11 +108,11 @@ struct ParticipantData { // Rendezvous objects can only be used once. 
class Rendezvous { public: - explicit Rendezvous(int64 num_participants) - : num_participants_(num_participants), - all_arrived_(num_participants), + explicit Rendezvous(const RendezvousKey& key) + : key_(key), + all_arrived_(key.num_participants), returned_blocking_counter_( - std::make_shared<BlockingCounter>(num_participants)) {} + std::make_shared<BlockingCounter>(key.num_participants)) {} // Runs the collective permute on the given thread. // @@ -98,7 +124,7 @@ class Rendezvous { ParticipantData participant); private: - const int64 num_participants_; + const RendezvousKey key_; BlockingCounter all_arrived_; // BlockingCounter returned by SubmitParticipant. @@ -146,7 +172,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant( if (primary) { initialized_ = true; returned_blocking_counter_ = - std::make_shared<BlockingCounter>(num_participants_); + std::make_shared<BlockingCounter>(key_.num_participants); } } @@ -155,7 +181,13 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant( // everyone, then we wouldn't be able to enqueue the copies at the correct // point in their streams. all_arrived_.DecrementCount(); - all_arrived_.Wait(); + WaitAndLogIfStuck(&all_arrived_, [&] { + return absl::StrFormat( + "participant for replica %d (stream %p, device %d) waiting for all " + "other participants to arrive: %s", + participant.replica_id, participant.stream, + participant.stream->parent()->device_ordinal(), key_.ToString()); + }); // Schedule the copies between the devices. This is much easier to reason // about if we schedule all of the copies from just one thread. The copies @@ -185,7 +217,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant( RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() { static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>( [](const RendezvousKey& key) { - return absl::make_unique<Rendezvous>(key.num_participants); + return absl::make_unique<Rendezvous>(key); }); return m; }
rendezvous.reset(); final_sync->DecrementCount(); - final_sync->Wait(); + WaitAndLogIfStuck(final_sync.get(), [&] { + return absl::StrFormat( + "participant for replica %d (stream %p, device ordinal %d) waiting for " + "all threads to drop their reference to the rendezvous: %s", + participant.replica_id, participant.stream, + participant.stream->parent()->device_ordinal(), key.ToString()); + }); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 5674d400558..2e1bee65388 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #if GOOGLE_CUDA +#include <chrono>  // NOLINT (required by TF interfaces) #include #include #include @@ -114,6 +115,24 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line, } \ } while (0) +template <typename DescFn> +void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter, + const DescFn& desc_fn) { + VLOG(3) << "Begin: " << desc_fn(); + const std::chrono::milliseconds timeout(5000); + bool ok = counter->WaitFor(timeout); + if (ok) { + VLOG(3) << "Finished: " << desc_fn(); + return; + } + LOG(ERROR) << "This thread has been waiting for " << timeout.count() + << "ms and may be stuck: " << desc_fn(); + counter->Wait(); + LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. " + "Perhaps the timeout is too short: " + << desc_fn(); +} + // RAII class owning a ncclComm_t, ensuring it doesn't leak. class NcclComm { public: @@ -390,7 +409,8 @@ RefcountingHashMap& GlobalNcclCliqueMap() { // Rendezvous objects can only be used once. 
class Rendezvous { public: - Rendezvous() = default; + explicit Rendezvous(const RendezvousKey& k) + : key_(k), all_participants_present_(k.participating_replicas.size()) {} // Runs the all-reduce on the given thread. If successful, returns // - a handle to the clique that was used, so that the caller may keep the @@ -405,8 +425,10 @@ class Rendezvous { private: Status DoAllReduce(ParticipantData participant, ncclComm_t comm); + const RendezvousKey key_; + tensorflow::BlockingCounter all_participants_present_; + tensorflow::mutex mu_; - tensorflow::condition_variable all_participants_present_; bool initialized_ GUARDED_BY(mu_) = false; absl::optional<tensorflow::BlockingCounter> done_; @@ -424,23 +446,14 @@ class Rendezvous { // Rendezvous objects are one-time use, so they're removed from this map once // we're through with them. RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() { - static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(); + static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>( + [](const RendezvousKey& k) { return absl::make_unique<Rendezvous>(k); }); return m; } StatusOr<std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<tensorflow::BlockingCounter>>> Rendezvous::SubmitParticipant(ParticipantData participant) { - // We pull into our thread a) the communication handle and b) whether we're - // the "primary" thread for this rendezvous -- the "primary" thread has some - // additional responsibilities for setup/teardown. - ncclComm_t comm; - bool primary; - std::shared_ptr<NcclClique> clique; - - // Releases the lock on the clique (held only by the primary thread). - Cleanup<std::function<void()>> clique_lock_releaser; - { tensorflow::mutex_lock lock(mu_); CHECK(!initialized_); @@ -455,14 +468,29 @@ Rendezvous::SubmitParticipant(ParticipantData participant) { participants_.back().ToString(), participant.ToString()); } participants_.push_back(participant); + } - // Wait here for all participants to arrive. 
- while (participants_.size() < participant.num_participants()) { - all_participants_present_.wait(lock); - } - if (participants_.size() == participant.num_participants()) { - all_participants_present_.notify_all(); - } + // Wait for all participants to arrive. + all_participants_present_.DecrementCount(); + WaitAndLogIfStuck(&all_participants_present_, [&] { + return absl::StrFormat( + "participant for device ordinal %d, stream %p waiting for all " + "participants to arrive at rendezvous %s", + participant.device_ordinal, participant.stream, key_.ToString()); + }); + + // We pull into our thread a) the communication handle and b) whether we're + // the "primary" thread for this rendezvous -- the "primary" thread has some + // additional responsibilities for setup/teardown. + ncclComm_t comm; + bool primary; + std::shared_ptr<NcclClique> clique; + + // Releases the lock on the clique (held only by the primary thread). + Cleanup<std::function<void()>> clique_lock_releaser; + + { + tensorflow::mutex_lock lock(mu_); // The first thread to get here has additional responsibilities, such as // ensuring that there's a NCCL clique available for us to use. @@ -512,10 +540,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) { // The primary owns the lock on the NCCL clique. Hold it until all threads // are done. (We'll release it when we return from this function.) 
if (primary) { - VLOG(3) - << "Primary waiting for all participants to complete all-reduce op."; - done_->Wait(); - VLOG(3) << "All participants completed all-reduce op."; + WaitAndLogIfStuck(&*done_, [&] { + return absl::StrFormat( + "primary participant (device ordinal %d, stream %p) waiting for all " + "other participants to complete all-reduce %s", + participant.device_ordinal, participant.stream, key_.ToString()); + }); } VLOG(3) << "Returning status: " << all_reduce_status; @@ -624,6 +654,7 @@ static StatusOr<std::vector<int64>> GetParticipatingReplicas( } Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { + VLOG(1) << "Starting NcclAllReduceThunk."; auto op_profiler = params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); @@ -642,6 +673,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { std::shared_ptr<Rendezvous> rendezvous = GlobalRendezvousMap()[rendezvous_key]; + VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString() + << ", rendezvous: " << rendezvous.get() + << ", participating replicas: " + << absl::StrJoin(participating_replicas, ", "); + ParticipantData participant(rendezvous_key); participant.element_count = element_count_; participant.device_ordinal = device_ordinal; @@ -682,7 +718,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { // erase() is deceptively complex to implement correctly. rendezvous.reset(); blocking_counter->DecrementCount(); - blocking_counter->Wait(); + WaitAndLogIfStuck(blocking_counter.get(), [&] { + return absl::StrFormat( + "participant for device ordinal %d, stream %p waiting for " + "all threads to drop their reference to the rendezvous: %s", + device_ordinal, params.stream, rendezvous_key.ToString()); + }); return Status::OK(); }