diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h index 1552397d4bc..b896264538d 100644 --- a/tensorflow/compiler/xla/service/collective_ops_utils.h +++ b/tensorflow/compiler/xla/service/collective_ops_utils.h @@ -236,7 +236,7 @@ class Rendezvous { "rendezvous: %p", rendezvous.get()); }); - return p.first; + return std::move(p.first); } protected: diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc index 6b6a9351384..b240138d771 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc @@ -16,9 +16,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nccl_utils.h" #include +#include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/refcounting_hash_map.h" #include "tensorflow/compiler/xla/service/collective_ops_utils.h" @@ -221,7 +223,8 @@ class NcclCliqueRendezvous : Rendezvous(rendezvous_key), key_(std::move(rendezvous_key.global_devices)), local_participants_(local_participants), - callback_(callback) {} + callback_(callback), + counter_(nullptr) {} StatusOr RunCollectiveOp( const NcclCliqueParticipantData&) override { @@ -235,10 +238,13 @@ class NcclCliqueRendezvous initialized_ = true; } TF_ASSIGN_OR_RETURN(std::shared_ptr clique, maybe_clique_); + std::unique_ptr clique_lock; if (primary) { - lock_ = std::make_shared(clique->mu()); + clique_lock = std::make_unique(clique->mu()); + counter_ = new absl::BlockingCounter(local_participants_.size()); } - return LockedNcclClique{clique, lock_}; + return LockedNcclClique(std::move(clique), std::move(clique_lock), + counter_); } private: @@ -247,7 +253,7 @@ class NcclCliqueRendezvous const NcclUniqueIdCallback* callback_; StatusOr> maybe_clique_; - std::shared_ptr lock_; + absl::BlockingCounter* counter_; }; } // namespace @@ -282,6 +288,26 @@ StatusOr> GetLocalParticipants( return local_participants; } +LockedNcclClique::LockedNcclClique(std::shared_ptr clique, + std::unique_ptr lock, + absl::BlockingCounter* counter) + : clique(std::move(clique)), lock_(std::move(lock)), counter_(counter) {} + +LockedNcclClique::LockedNcclClique(LockedNcclClique&& other) + : clique(std::move(other.clique)), + lock_(std::move(other.lock_)), + counter_(std::exchange(other.counter_, nullptr)) {} + +LockedNcclClique::~LockedNcclClique() { + if (counter_) { + counter_->DecrementCount(); + if (lock_) { + counter_->Wait(); // Don't release lock until all threads are finished. + delete counter_; + } + } +} + StatusOr AcquireNcclClique( const RendezvousKey& rendezvous_key, int local_device_ordinal, se::Stream* stream, const std::vector& local_participants, diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h index 4a045d063e7..f24231d15d3 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #if GOOGLE_CUDA #include "third_party/nccl/nccl.h" @@ -110,13 +111,21 @@ StatusOr> GetLocalParticipants( const std::vector& participants, const std::vector* local_devices); // may be null -struct LockedNcclClique { +class LockedNcclClique { + public: + LockedNcclClique(std::shared_ptr clique, + std::unique_ptr lock, + absl::BlockingCounter* counter); + LockedNcclClique(LockedNcclClique&&); + ~LockedNcclClique(); + std::shared_ptr clique; + + private: // Must come after clique, so it is destroyed first. - // This lock prevents other threads from using this clique. All of the threads - // involved should hold onto the lock until they have finished with their - // communicator. - std::shared_ptr lock; + // One thread holds a lock (it is null in the others). + std::unique_ptr lock_; + absl::BlockingCounter* counter_; }; // Acquires a locked NCCL clique for use in NCCL collective operations.