diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
index 4fb9b06d02f..60301b4de64 100644
--- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
@@ -67,7 +67,7 @@ void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
 // because we use that information when constructing the Rendezvous.
 struct RendezvousKey {
   RunId run_id;
-  int64 num_participants;
+  int num_participants;  // int, not int64, to match BlockingCounter's counter.
 
   string ToString() const {
     return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
@@ -108,11 +108,7 @@ struct ParticipantData {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(const RendezvousKey& key)
-      : key_(key),
-        all_arrived_(key.num_participants),
-        returned_blocking_counter_(
-            std::make_shared<BlockingCounter>(key.num_participants)) {}
+  explicit Rendezvous(const RendezvousKey& key) : key_(key) {}
 
   // Runs the collective permute on the given thread.
   //
@@ -125,10 +121,11 @@ class Rendezvous {
 
  private:
   const RendezvousKey key_;
-  BlockingCounter all_arrived_;
+  BlockingCounter all_arrived_{key_.num_participants};
 
   // BlockingCounter returned by SubmitParticipant.
-  std::shared_ptr<BlockingCounter> returned_blocking_counter_;
+  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
+      std::make_shared<BlockingCounter>(key_.num_participants)};
 
   tensorflow::mutex mu_;
   bool initialized_ GUARDED_BY(mu_) = false;
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 2e1bee65388..995a1fcd676 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -72,6 +72,8 @@ namespace gpu {
 #if GOOGLE_CUDA
 namespace {
 
+using tensorflow::BlockingCounter;
+
 // Functions to translate an ncclResult_t/cudaError_t to a Status object. Used
 // by the macros below.
 Status TranslateStatus(ncclResult_t s, const char* file, int64 line,
@@ -116,8 +118,7 @@ Status TranslateStatus(cudaError_t s, const char* file, int64 line,
   } while (0)
 
 template <typename DescFn>
-void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
-                       const DescFn& desc_fn) {
+void WaitAndLogIfStuck(BlockingCounter* counter, const DescFn& desc_fn) {
   VLOG(3) << "Begin: " << desc_fn();
   const std::chrono::milliseconds timeout(5000);
   bool ok = counter->WaitFor(timeout);
@@ -202,6 +203,8 @@ struct RendezvousKey {
         static_cast<int64>(instr->GetModule()->unique_id()));
   }
 
+  int num_participants() const { return participating_replicas.size(); }
+
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
     return H::combine(std::move(h), k.run_id, k.participating_replicas,
@@ -248,9 +251,7 @@ struct ParticipantData {
   se::DeviceMemoryBase destination_data;
   se::Stream* stream;
 
-  int64 num_participants() const {
-    return rendezvous_key.participating_replicas.size();
-  }
+  int num_participants() const { return rendezvous_key.num_participants(); }
 
   string ToString() const {
     return absl::StrFormat(
@@ -409,8 +410,7 @@ RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
 // Rendezvous objects can only be used once.
 class Rendezvous {
  public:
-  explicit Rendezvous(const RendezvousKey& k)
-      : key_(k), all_participants_present_(k.participating_replicas.size()) {}
+  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
 
   // Runs the all-reduce on the given thread. If successful, returns
   // - a handle to the clique that was used, so that the caller may keep the
@@ -418,25 +418,26 @@ class Rendezvous {
   // - a BlockingCounter initialized to the number of participants, so that
   //   the caller can coordinate with the participants one last time if it
   //   chooses. This is useful for coordinating destruction of the Rendezvous.
-  StatusOr<std::pair<std::shared_ptr<NcclClique>,
-                     std::shared_ptr<tensorflow::BlockingCounter>>>
+  StatusOr<
+      std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
   SubmitParticipant(ParticipantData participant);
 
  private:
   Status DoAllReduce(ParticipantData participant, ncclComm_t comm);
 
   const RendezvousKey key_;
-  tensorflow::BlockingCounter all_participants_present_;
+
+  BlockingCounter all_participants_present_{key_.num_participants()};
+  BlockingCounter done_{key_.num_participants()};
+
+  // BlockingCounter returned by SubmitParticipant.
+  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
+      std::make_shared<BlockingCounter>(key_.num_participants())};
 
   tensorflow::mutex mu_;
 
   bool initialized_ GUARDED_BY(mu_) = false;
-  absl::optional<tensorflow::BlockingCounter> done_;
-  std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
-
-  // BlockingCounter returned by SubmitParticipant. Initialized by the primary
-  // thread.
-  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_;
+  std::vector<ParticipantData> participants_ GUARDED_BY(mu_);
 };
 
 // Global map of Rendezvous objects. A thread participating in a collective op
@@ -451,8 +452,8 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
   return m;
 }
 
-StatusOr<std::pair<std::shared_ptr<NcclClique>,
-                   std::shared_ptr<tensorflow::BlockingCounter>>>
+StatusOr<
+    std::pair<std::shared_ptr<NcclClique>, std::shared_ptr<BlockingCounter>>>
 Rendezvous::SubmitParticipant(ParticipantData participant) {
   {
     tensorflow::mutex_lock lock(mu_);
@@ -506,10 +507,6 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   if (primary) {
     VLOG(3) << "Primary initializing accounting data.";
     initialized_ = true;
-    done_.emplace(participant.num_participants());
-    returned_blocking_counter_ =
-        std::make_shared<tensorflow::BlockingCounter>(
-            participant.num_participants());
 
     // Acquire exclusive access to the NCCL clique itself so that two
     // unrelated collective operations won't try to use the clique
@@ -535,12 +532,12 @@ Rendezvous::SubmitParticipant(ParticipantData participant) {
   Status all_reduce_status = DoAllReduce(participant, comm);
   VLOG(3) << "This thread done with all-reduce op.";
 
-  done_->DecrementCount();
+  done_.DecrementCount();
 
   // The primary owns the lock on the NCCL clique. Hold it until all threads
   // are done. (We'll release it when we return from this function.)
   if (primary) {
-    WaitAndLogIfStuck(&*done_, [&] {
+    WaitAndLogIfStuck(&done_, [&] {
      return absl::StrFormat(
          "primary participant (device ordinal %d, stream %p) waiting for all "
          "other participants to complete all-reduce %s",
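
A note on the pattern this patch adopts: the constructor initializer lists are replaced with in-class (default) member initializers. This is well-defined because non-static data members are initialized in declaration order, so a member's initializer may read any member declared before it; here key_ is declared first, and the counters are sized from it. Below is a minimal, self-contained sketch of the idiom. Counter, Key, and the member names are illustrative stand-ins, not the patch's actual types (those are tensorflow::BlockingCounter and XLA's RendezvousKey).

#include <memory>

// Stand-in for tensorflow::BlockingCounter: all we need here is a type
// that must be constructed with the participant count.
class Counter {
 public:
  explicit Counter(int n) : n_(n) {}
  int count() const { return n_; }

 private:
  int n_;
};

struct Key {
  int num_participants;
};

class Rendezvous {
 public:
  // The constructor now only captures the key; everything derived from it
  // is set up by the default member initializers below.
  explicit Rendezvous(const Key& key) : key_(key) {}

  const Counter& all_arrived() const { return all_arrived_; }

 private:
  // key_ is declared (and therefore initialized) first, so the
  // initializers below may safely read key_.num_participants.
  const Key key_;
  Counter all_arrived_{key_.num_participants};
  std::shared_ptr<Counter> returned_{
      std::make_shared<Counter>(key_.num_participants)};
};

int main() {
  Rendezvous r(Key{4});
  return r.all_arrived().count() == 4 ? 0 : 1;  // exits 0
}

Keeping key_ as the first member is load-bearing: if all_arrived_ were declared before key_, its initializer would read an uninitialized key_.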
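
The done_ counter's change from absl::optional<BlockingCounter> to a plain member follows the same logic: the participant count used to be known only once the primary thread saw a ParticipantData, but RendezvousKey::num_participants() now makes it available at construction time, so the late .emplace(...) and the &*done_ dereference disappear. The coordination protocol itself is unchanged: every participant calls DecrementCount(), and the primary waits in WaitAndLogIfStuck(), which first tries WaitFor() with a timeout so a wedged rendezvous gets logged instead of hanging silently. Here is a runnable sketch of that protocol; DecrementCount(), WaitFor(), and the 5000 ms timeout come from the patch, but this simplified class body is an assumption standing in for tensorflow::BlockingCounter.

#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

class BlockingCounter {
 public:
  explicit BlockingCounter(int n) : n_(n) {}

  // Each participant calls this once when it is done.
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--n_ == 0) cv_.notify_all();
  }

  // Returns true if the count reached zero before the timeout expired.
  bool WaitFor(std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lock(mu_);
    return cv_.wait_for(lock, timeout, [this] { return n_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int n_;
};

int main() {
  constexpr int kParticipants = 4;
  BlockingCounter done(kParticipants);

  std::vector<std::thread> threads;
  for (int i = 0; i < kParticipants; ++i) {
    threads.emplace_back([&done] { done.DecrementCount(); });
  }

  // Like WaitAndLogIfStuck: wait with a timeout, log if we appear stuck.
  if (!done.WaitFor(std::chrono::milliseconds(5000))) {
    std::fprintf(stderr, "participants appear stuck\n");
  }
  for (auto& t : threads) t.join();
  return 0;
}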