From 90ec1352d978c812d2f6c07050feeecb427d44c4 Mon Sep 17 00:00:00 2001 From: Loren Maggiore Date: Fri, 29 Jan 2021 04:08:22 -0800 Subject: [PATCH] Store NcclCliques in new NcclCliqueMap. PiperOrigin-RevId: 354505971 Change-Id: If7f52af22a37e125a587971abbe665f855335e37 --- .../compiler/xla/refcounting_hash_map.h | 31 +------------- .../compiler/xla/refcounting_hash_map_test.cc | 38 ----------------- .../xla/service/gpu/nccl_collective_thunk.cc | 6 +-- .../xla/service/gpu/nccl_collective_thunk.h | 9 +--- .../compiler/xla/service/gpu/nccl_utils.cc | 42 +++++++++++++------ .../compiler/xla/service/gpu/nccl_utils.h | 28 ++++++++++--- 6 files changed, 57 insertions(+), 97 deletions(-) diff --git a/tensorflow/compiler/xla/refcounting_hash_map.h b/tensorflow/compiler/xla/refcounting_hash_map.h index 7b79ea925b0..b92aab90638 100644 --- a/tensorflow/compiler/xla/refcounting_hash_map.h +++ b/tensorflow/compiler/xla/refcounting_hash_map.h @@ -59,18 +59,6 @@ class RefcountingHashMap { std::shared_ptr GetOrCreateIfAbsent( const K& key, const std::function(const K&)>& value_factory) { - return *GetOrTryCreateIfAbsent(key, [&](const K& key) { - return StatusOr>(value_factory(key)); - }); - } - - // Gets the value for the given key. - // - // If the map doesn't contain a live value for the key, constructs one - // using `value_factory`, or returns the status from `value_factory`. - StatusOr> GetOrTryCreateIfAbsent( - const K& key, const std::function>(const K&)>& - value_factory) { absl::MutexLock lock(&mu_); auto it = map_.find(key); if (it != map_.end()) { @@ -83,28 +71,13 @@ class RefcountingHashMap { // Create entry in the map and then set its value, so the value can // contain a pointer back into the map. 
- TF_ASSIGN_OR_RETURN(std::unique_ptr value_unique, value_factory(key)); it = map_.emplace(key, std::weak_ptr()).first; - std::shared_ptr value(value_unique.release(), Deleter{it->first, *this}); + std::shared_ptr value(value_factory(key).release(), + Deleter{it->first, *this}); it->second = value; // Set the weak ptr to the shared ptr. return value; } - // Runs a function over every key/value in the map. - // - // Touching the map from within this function may deadlock; don't do it. - // - // Function signature must be compatible with - // void fn(const K&, std::shared_ptr) - // - template - void ForEach(Fn&& fn) { - absl::MutexLock lock(&mu_); - for (const auto& kv : map_) { - fn(kv.first, kv.second.lock()); - } - } - private: struct Deleter { const K& key; // Points into parent->map_. diff --git a/tensorflow/compiler/xla/refcounting_hash_map_test.cc b/tensorflow/compiler/xla/refcounting_hash_map_test.cc index 8ead034d1bc..91a5cf89e72 100644 --- a/tensorflow/compiler/xla/refcounting_hash_map_test.cc +++ b/tensorflow/compiler/xla/refcounting_hash_map_test.cc @@ -81,43 +81,5 @@ TEST(RefcountingHashMapTest, CustomFactory) { EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101); } -TEST(RefcountingHashMapTest, TrySuccessful) { - RefcountingHashMap m; - auto factory = [](const int&) { return absl::make_unique(7); }; - StatusOr> result = m.GetOrTryCreateIfAbsent(42, factory); - ASSERT_TRUE(result.ok()); - EXPECT_EQ(**result, 7); -} - -TEST(RefcountingHashMapTest, TryFailure) { - RefcountingHashMap m; - Status status = tensorflow::errors::Internal("Arrggg!"); - auto factory = [&](const int&) { return status; }; - EXPECT_EQ(m.GetOrTryCreateIfAbsent(42, factory).status(), status); -} - -TEST(RefcountingHashMapTest, ForEachEmpty) { - RefcountingHashMap m; - int64 count = 0; - m.ForEach([&](const int&, std::shared_ptr) { ++count; }); - EXPECT_EQ(count, 0); -} - -TEST(RefcountingHashMapTest, ForEachNonempty) { - RefcountingHashMap m; - auto factory = [](const int&) { return 
absl::make_unique(); }; - auto a = m.GetOrCreateIfAbsent(0, factory); - auto b = m.GetOrCreateIfAbsent(1, factory); - - std::vector seen_keys; - std::vector seen_values; - m.ForEach([&](const int& k, std::shared_ptr v) { - seen_keys.push_back(k); - seen_values.push_back(v.get()); - }); - EXPECT_THAT(seen_keys, testing::UnorderedElementsAre(0, 1)); - EXPECT_THAT(seen_values, testing::UnorderedElementsAre(a.get(), b.get())); -} - } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc index 7174eefcfb9..f49561f888d 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc @@ -115,16 +115,12 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { AcquireNcclClique(rendezvous_key, device_ordinal, params.stream, local_participants, params.nccl_unique_id_callback)); ncclComm_t comm = - locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal); + locked_clique.clique.GetCommForDeviceOrdinal(device_ordinal); se::StreamExecutor* executor = params.stream->parent(); se::gpu::ScopedActivateExecutorContext scoped_context(executor); TF_RETURN_IF_ERROR(RunNcclCollective(params, comm)); - - // Keep the clique we used alive for as long as this thunk lives. - absl::MutexLock lock(&mu_); - cliques_.insert(std::move(locked_clique.clique)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h index 51720f91bd1..004a2c27678 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h @@ -93,19 +93,12 @@ class NcclCollectiveThunk : public Thunk { // error. 
static bool NcclIsEnabled(); - Status ExecuteOnStream(const ExecuteParams& params) override - ABSL_LOCKS_EXCLUDED(mu_); + Status ExecuteOnStream(const ExecuteParams& params) override; protected: virtual Status RunNcclCollective(const ExecuteParams& params, ncclComm_t comm) = 0; virtual const NcclCollectiveConfig& config() const = 0; - - private: - // Creating NCCL cliques is expensive, so we cache them. - absl::Mutex mu_; - absl::flat_hash_set> cliques_ - ABSL_GUARDED_BY(mu_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc index b240138d771..8c892cb7907 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc @@ -105,16 +105,14 @@ ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const { return comms_by_device_ordinal_.at(device_ordinal).get(); } -RefcountingHashMap& NcclCliqueCache() { - // Global cache of NCCL cliques. An entry in this map is kept alive as long - // as there's a reference to it somewhere. A Thunk holds a reference to each - // Clique it's ever used. +NcclCliqueMap& NcclCliqueCache() { + // Global cache of NCCL cliques. An entry in this map is always kept alive. // // A consequence of the fact that this is process-global is that we'll only // ever have one clique alive for a given set of GPUs. This means that a // process will never do two collective operations concurrently on the same // set of GPUs. 
- static auto& cache = *new RefcountingHashMap(); + static auto& cache = *new NcclCliqueMap(); return cache; } @@ -237,14 +235,13 @@ class NcclCliqueRendezvous }); initialized_ = true; } - TF_ASSIGN_OR_RETURN(std::shared_ptr clique, maybe_clique_); + TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_); std::unique_ptr clique_lock; if (primary) { clique_lock = std::make_unique(clique->mu()); counter_ = new absl::BlockingCounter(local_participants_.size()); } - return LockedNcclClique(std::move(clique), std::move(clique_lock), - counter_); + return LockedNcclClique(*clique, std::move(clique_lock), counter_); } private: @@ -252,7 +249,7 @@ class NcclCliqueRendezvous const std::vector& local_participants_; const NcclUniqueIdCallback* callback_; - StatusOr> maybe_clique_; + StatusOr maybe_clique_; absl::BlockingCounter* counter_; }; @@ -288,13 +285,13 @@ StatusOr> GetLocalParticipants( return local_participants; } -LockedNcclClique::LockedNcclClique(std::shared_ptr clique, +LockedNcclClique::LockedNcclClique(NcclClique& clique, std::unique_ptr lock, absl::BlockingCounter* counter) - : clique(std::move(clique)), lock_(std::move(lock)), counter_(counter) {} + : clique(clique), lock_(std::move(lock)), counter_(counter) {} LockedNcclClique::LockedNcclClique(LockedNcclClique&& other) - : clique(std::move(other.clique)), + : clique(other.clique), lock_(std::move(other.lock_)), counter_(std::exchange(other.counter_, nullptr)) {} @@ -308,6 +305,27 @@ LockedNcclClique::~LockedNcclClique() { } } +StatusOr NcclCliqueMap::GetOrTryCreateIfAbsent( + const NcclCliqueKey& key, + const std::function>( + const NcclCliqueKey&)>& value_factory) { + absl::MutexLock lock(&mu_); + auto it = map_.find(key); + if (it == map_.end()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr value, value_factory(key)); + it = map_.emplace(key, std::move(value)).first; + } + return it->second.get(); +} + +void NcclCliqueMap::ForEach( + const std::function& fn) { + absl::MutexLock lock(&mu_); + for (const auto& 
kv : map_) { + fn(kv.first, *kv.second); + } +} + StatusOr AcquireNcclClique( const RendezvousKey& rendezvous_key, int local_device_ordinal, se::Stream* stream, const std::vector& local_participants, diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h index f24231d15d3..7662ce93df0 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h @@ -100,8 +100,6 @@ class NcclClique { absl::Mutex mu_; }; -RefcountingHashMap& NcclCliqueCache(); - struct LocalParticipant { int device_ordinal; int rank; @@ -113,13 +111,12 @@ StatusOr> GetLocalParticipants( class LockedNcclClique { public: - LockedNcclClique(std::shared_ptr clique, - std::unique_ptr lock, + LockedNcclClique(NcclClique& clique, std::unique_ptr lock, absl::BlockingCounter* counter); LockedNcclClique(LockedNcclClique&&); ~LockedNcclClique(); - std::shared_ptr clique; + NcclClique& clique; private: // Must come after clique, so it is destroyed first. @@ -128,6 +125,27 @@ class LockedNcclClique { absl::BlockingCounter* counter_; }; +// Threadsafe leaky map from NcclCliqueKeys to NcclCliques. +class NcclCliqueMap { + public: + // Gets the clique for `key`. If the map has no entry for it, constructs one + // using `value_factory` and inserts it, or returns the status from + // `value_factory` without inserting anything. + // + // `value_factory` is called while the map's mutex is held; touching the map + // from within it may deadlock. + StatusOr GetOrTryCreateIfAbsent( + const NcclCliqueKey& key, + const std::function>( + const NcclCliqueKey&)>& value_factory) ABSL_LOCKS_EXCLUDED(mu_); + + // Runs a function over every key/value in the map. + // + // `fn` is called while the map's mutex is held; touching the map from within + // `fn` may deadlock. + void ForEach( + const std::function& fn) + ABSL_LOCKS_EXCLUDED(mu_); + + private: + absl::Mutex mu_; + absl::flat_hash_map> map_ + ABSL_GUARDED_BY(mu_); +}; + +NcclCliqueMap& NcclCliqueCache(); + // Acquires a locked NCCL clique for use in NCCL collective operations. StatusOr AcquireNcclClique( const RendezvousKey& rendezvous_key, int local_device_ordinal,