diff --git a/tensorflow/compiler/xla/refcounting_hash_map.h b/tensorflow/compiler/xla/refcounting_hash_map.h
index 7b79ea925b0..b92aab90638 100644
--- a/tensorflow/compiler/xla/refcounting_hash_map.h
+++ b/tensorflow/compiler/xla/refcounting_hash_map.h
@@ -59,18 +59,6 @@ class RefcountingHashMap {
   std::shared_ptr<V> GetOrCreateIfAbsent(
       const K& key,
       const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
-    return *GetOrTryCreateIfAbsent(key, [&](const K& key) {
-      return StatusOr<std::unique_ptr<V>>(value_factory(key));
-    });
-  }
-
-  // Gets the value for the given key.
-  //
-  // If the map doesn't contain a live value for the key, constructs one
-  // using `value_factory`, or returns the status from `value_factory`.
-  StatusOr<std::shared_ptr<V>> GetOrTryCreateIfAbsent(
-      const K& key, const std::function<StatusOr<std::unique_ptr<V>>(const K&)>&
-                        value_factory) {
     absl::MutexLock lock(&mu_);
     auto it = map_.find(key);
     if (it != map_.end()) {
@@ -83,28 +71,13 @@ class RefcountingHashMap {
 
     // Create entry in the map and then set its value, so the value can
     // contain a pointer back into the map.
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<V> value_unique, value_factory(key));
     it = map_.emplace(key, std::weak_ptr<V>()).first;
-    std::shared_ptr<V> value(value_unique.release(), Deleter{it->first, *this});
+    std::shared_ptr<V> value(value_factory(key).release(),
+                             Deleter{it->first, *this});
     it->second = value;  // Set the weak ptr to the shared ptr.
     return value;
   }
 
-  // Runs a function over every key/value in the map.
-  //
-  // Touching the map from within this function may deadlock; don't do it.
-  //
-  // Function signature must be compatible with
-  //   void fn(const K&, std::shared_ptr<V>)
-  //
-  template <typename Fn>
-  void ForEach(Fn&& fn) {
-    absl::MutexLock lock(&mu_);
-    for (const auto& kv : map_) {
-      fn(kv.first, kv.second.lock());
-    }
-  }
-
  private:
   struct Deleter {
     const K& key;  // Points into parent->map_.
diff --git a/tensorflow/compiler/xla/refcounting_hash_map_test.cc b/tensorflow/compiler/xla/refcounting_hash_map_test.cc
index 8ead034d1bc..91a5cf89e72 100644
--- a/tensorflow/compiler/xla/refcounting_hash_map_test.cc
+++ b/tensorflow/compiler/xla/refcounting_hash_map_test.cc
@@ -81,43 +81,5 @@ TEST(RefcountingHashMapTest, CustomFactory) {
   EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
 }
 
-TEST(RefcountingHashMapTest, TrySuccessful) {
-  RefcountingHashMap<int, int> m;
-  auto factory = [](const int&) { return absl::make_unique<int>(7); };
-  StatusOr<std::shared_ptr<int>> result = m.GetOrTryCreateIfAbsent(42, factory);
-  ASSERT_TRUE(result.ok());
-  EXPECT_EQ(**result, 7);
-}
-
-TEST(RefcountingHashMapTest, TryFailure) {
-  RefcountingHashMap<int, int> m;
-  Status status = tensorflow::errors::Internal("Arrggg!");
-  auto factory = [&](const int&) { return status; };
-  EXPECT_EQ(m.GetOrTryCreateIfAbsent(42, factory).status(), status);
-}
-
-TEST(RefcountingHashMapTest, ForEachEmpty) {
-  RefcountingHashMap<int, int> m;
-  int64 count = 0;
-  m.ForEach([&](const int&, std::shared_ptr<int>) { ++count; });
-  EXPECT_EQ(count, 0);
-}
-
-TEST(RefcountingHashMapTest, ForEachNonempty) {
-  RefcountingHashMap<int, int> m;
-  auto factory = [](const int&) { return absl::make_unique<int>(); };
-  auto a = m.GetOrCreateIfAbsent(0, factory);
-  auto b = m.GetOrCreateIfAbsent(1, factory);
-
-  std::vector<int> seen_keys;
-  std::vector<int*> seen_values;
-  m.ForEach([&](const int& k, std::shared_ptr<int> v) {
-    seen_keys.push_back(k);
-    seen_values.push_back(v.get());
-  });
-  EXPECT_THAT(seen_keys, testing::UnorderedElementsAre(0, 1));
-  EXPECT_THAT(seen_values, testing::UnorderedElementsAre(a.get(), b.get()));
-}
-
 }  // anonymous namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
index 7174eefcfb9..f49561f888d 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc
@@ -115,16 +115,12 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
       AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
                         local_participants, params.nccl_unique_id_callback));
   ncclComm_t comm =
-      locked_clique.clique->GetCommForDeviceOrdinal(device_ordinal);
+      locked_clique.clique.GetCommForDeviceOrdinal(device_ordinal);
 
   se::StreamExecutor* executor = params.stream->parent();
   se::gpu::ScopedActivateExecutorContext scoped_context(executor);
 
   TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
-
-  // Keep the clique we used alive for as long as this thunk lives.
-  absl::MutexLock lock(&mu_);
-  cliques_.insert(std::move(locked_clique.clique));
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
index 51720f91bd1..004a2c27678 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
@@ -93,19 +93,12 @@ class NcclCollectiveThunk : public Thunk {
   // error.
   static bool NcclIsEnabled();
 
-  Status ExecuteOnStream(const ExecuteParams& params) override
-      ABSL_LOCKS_EXCLUDED(mu_);
+  Status ExecuteOnStream(const ExecuteParams& params) override;
 
  protected:
   virtual Status RunNcclCollective(const ExecuteParams& params,
                                    ncclComm_t comm) = 0;
   virtual const NcclCollectiveConfig& config() const = 0;
-
- private:
-  // Creating NCCL cliques is expensive, so we cache them.
-  absl::Mutex mu_;
-  absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques_
-      ABSL_GUARDED_BY(mu_);
 };
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
index b240138d771..8c892cb7907 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -105,16 +105,14 @@ ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
   return comms_by_device_ordinal_.at(device_ordinal).get();
 }
 
-RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache() {
-  // Global cache of NCCL cliques.  An entry in this map is kept alive as long
-  // as there's a reference to it somewhere.  A Thunk holds a reference to each
-  // Clique it's ever used.
+NcclCliqueMap& NcclCliqueCache() {
+  // Global cache of NCCL cliques.  An entry in this map is always kept alive.
   //
   // A consequence of the fact that this is process-global is that we'll only
   // ever have one clique alive for a given set of GPUs.  This means that a
   // process will never do two collective operations concurrently on the same
   // set of GPUs.
-  static auto& cache = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
+  static auto& cache = *new NcclCliqueMap();
   return cache;
 }
 
@@ -237,14 +235,13 @@ class NcclCliqueRendezvous
           });
       initialized_ = true;
     }
-    TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique, maybe_clique_);
+    TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_);
     std::unique_ptr<absl::MutexLock> clique_lock;
     if (primary) {
       clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
       counter_ = new absl::BlockingCounter(local_participants_.size());
     }
-    return LockedNcclClique(std::move(clique), std::move(clique_lock),
-                            counter_);
+    return LockedNcclClique(*clique, std::move(clique_lock), counter_);
   }
 
  private:
@@ -252,7 +249,7 @@ class NcclCliqueRendezvous
   const std::vector<LocalParticipant>& local_participants_;
   const NcclUniqueIdCallback* callback_;
 
-  StatusOr<std::shared_ptr<NcclClique>> maybe_clique_;
+  StatusOr<NcclClique*> maybe_clique_;
   absl::BlockingCounter* counter_;
 };
 
@@ -288,13 +285,13 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
   return local_participants;
 }
 
-LockedNcclClique::LockedNcclClique(std::shared_ptr<NcclClique> clique,
+LockedNcclClique::LockedNcclClique(NcclClique& clique,
                                    std::unique_ptr<absl::MutexLock> lock,
                                    absl::BlockingCounter* counter)
-    : clique(std::move(clique)), lock_(std::move(lock)), counter_(counter) {}
+    : clique(clique), lock_(std::move(lock)), counter_(counter) {}
 
 LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
-    : clique(std::move(other.clique)),
+    : clique(other.clique),
       lock_(std::move(other.lock_)),
       counter_(std::exchange(other.counter_, nullptr)) {}
 
@@ -308,6 +305,27 @@ LockedNcclClique::~LockedNcclClique() {
   }
 }
 
+StatusOr<NcclClique*> NcclCliqueMap::GetOrTryCreateIfAbsent(
+    const NcclCliqueKey& key,
+    const std::function<StatusOr<std::unique_ptr<NcclClique>>(
+        const NcclCliqueKey&)>& value_factory) {
+  absl::MutexLock lock(&mu_);
+  auto it = map_.find(key);
+  if (it == map_.end()) {
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
+    it = map_.emplace(key, std::move(value)).first;
+  }
+  return it->second.get();
+}
+
+void NcclCliqueMap::ForEach(
+    const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn) {
+  absl::MutexLock lock(&mu_);
+  for (const auto& kv : map_) {
+    fn(kv.first, *kv.second);
+  }
+}
+
 StatusOr<LockedNcclClique> AcquireNcclClique(
     const RendezvousKey& rendezvous_key, int local_device_ordinal,
     se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
index f24231d15d3..7662ce93df0 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
@@ -100,8 +100,6 @@ class NcclClique {
   absl::Mutex mu_;
 };
 
-RefcountingHashMap<NcclCliqueKey, NcclClique>& NcclCliqueCache();
-
 struct LocalParticipant {
   int device_ordinal;
   int rank;
@@ -113,13 +111,12 @@ StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
 
 class LockedNcclClique {
  public:
-  LockedNcclClique(std::shared_ptr<NcclClique> clique,
-                   std::unique_ptr<absl::MutexLock> lock,
+  LockedNcclClique(NcclClique& clique, std::unique_ptr<absl::MutexLock> lock,
                    absl::BlockingCounter* counter);
   LockedNcclClique(LockedNcclClique&&);
   ~LockedNcclClique();
 
-  std::shared_ptr<NcclClique> clique;
+  NcclClique& clique;
 
  private:
   // Must come after clique, so it is destroyed first.
@@ -128,6 +125,27 @@ class LockedNcclClique {
   absl::BlockingCounter* counter_;
 };
 
+// Thread-safe, deliberately leaky map from NcclCliqueKeys to NcclCliques;
+class NcclCliqueMap {
+ public:
+  StatusOr<NcclClique*> GetOrTryCreateIfAbsent(
+      const NcclCliqueKey& key,
+      const std::function<StatusOr<std::unique_ptr<NcclClique>>(
+          const NcclCliqueKey&)>& value_factory) ABSL_LOCKS_EXCLUDED(mu_);
+
+  // Runs a function over every key/value in the map.
+  void ForEach(
+      const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn)
+      ABSL_LOCKS_EXCLUDED(mu_);
+
+ private:
+  absl::Mutex mu_;
+  absl::flat_hash_map<NcclCliqueKey, std::unique_ptr<NcclClique>> map_
+      ABSL_GUARDED_BY(mu_);
+};
+
+NcclCliqueMap& NcclCliqueCache();
+
 // Acquires a locked NCCL clique for use in NCCL collective operations.
 StatusOr<LockedNcclClique> AcquireNcclClique(
     const RendezvousKey& rendezvous_key, int local_device_ordinal,