diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 318d60b22df..9ba9ad61aa0 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -19,9 +19,9 @@ cc_library(
     deps = [
         ":tpu_compile_op_support",
         ":tpu_mesh_state_interface",
+        ":tpu_program_group_interface",
         ":tpu_util",
         ":tpu_util_hdrs",
-        "@com_google_absl//absl/types:span",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
@@ -30,16 +30,16 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:compile_only_client",
-        "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        # "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
         "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+        "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
         "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "@com_google_absl//absl/types:span",
     ],
     alwayslink = 1,
 )
@@ -157,14 +157,28 @@ cc_library(
         "tpu_compilation_cache_entry.h",
     ],
     deps = [
+        ":compiled_subgraph",
+        ":tpu_compilation_cache_proto_cc",
         ":tpu_executable_info_proto_cc",
         ":tpu_program_group",
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "//tensorflow/core:framework",
         "//tensorflow/core/lib/core:refcount",
         "//tensorflow/core/platform:casts",
     ],
 )
 
+cc_library(
+    name = "tpu_compilation_cache_entry_impl",
+    srcs = [],
+    hdrs = ["tpu_compilation_cache_entry_impl.h"],
+    deps = [
+        ":compiled_subgraph",
+        ":tpu_compilation_cache_interface",
+        ":tpu_executable_info_proto_cc",
+    ],
+)
+
 cc_library(
     name = "tpu_compilation_cache_lookup",
     srcs = ["tpu_compilation_cache_lookup.cc"],
@@ -174,6 +188,7 @@ cc_library(
     deps = [
         ":tpu_compilation_cache_entry",
         ":tpu_compilation_cache_external",
+        ":tpu_compilation_cache_interface",
         ":tpu_compilation_cache_proto_cc",
         "//tensorflow/core/lib/core:refcount",
         "//tensorflow/core/platform:status",
@@ -247,6 +262,35 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tpu_compilation_cache_interface",
+    srcs = ["tpu_compilation_cache_interface.cc"],
+    hdrs = ["tpu_compilation_cache_interface.h"],
+    deps = [
+        ":compiled_subgraph",
+        ":tpu_compilation_cache_key",
+        ":tpu_compilation_cache_metrics_hdrs",
+        ":tpu_compilation_cache_proto_cc",
+        ":tpu_util",
+        ":tpu_util_hdrs",
+        ":trace_util_hdrs",
+        "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_call",
+        "//tensorflow/core/platform:casts",  # buildcleaner: keep
+        "//tensorflow/core/profiler/lib:traceme",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "tpu_compilation_cache_external",
     srcs = ["tpu_compilation_cache_external.cc"],
@@ -256,6 +300,8 @@ cc_library(
     deps = [
         ":compiled_subgraph",
         ":tpu_compilation_cache_entry",
+        ":tpu_compilation_cache_entry_impl",
+        ":tpu_compilation_cache_interface",
         ":tpu_compilation_cache_key",
         ":tpu_compilation_cache_metrics",  # buildcleaner: keep
         ":tpu_compilation_cache_metrics_hdrs",
@@ -355,6 +401,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:compile_only_client",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/tpu/kernels/compiled_subgraph.h b/tensorflow/core/tpu/kernels/compiled_subgraph.h
index 1066e4839dd..a97c652c279 100644
--- a/tensorflow/core/tpu/kernels/compiled_subgraph.h
+++ b/tensorflow/core/tpu/kernels/compiled_subgraph.h
@@ -25,6 +25,9 @@ limitations under the License.
 namespace tensorflow {
 namespace tpu {
 
+// Forward declaration to avoid circular dependency.
+class TpuCompilationCacheInterface;
+
 // Cache for compiled TPU program.
 //
 // Each key identifies a unique subgraph, and the value is the vector of
@@ -100,10 +103,7 @@ namespace tpu {
 // unmarked and set to most recently used.
 //
 struct CompiledSubgraph : public core::RefCounted {
-  // TODO(henrytan): once `TpuCompilationCache` and
-  // `TpuCompilationCacheExternal` inherits from `TpuCompilationCacheInterface`
-  // update void* with `TpuCompilationCacheInterface`
-  void* parent = nullptr;  // Not owned.
+  TpuCompilationCacheInterface* parent = nullptr;  // Not owned.
 
   bool initialized = false;
 
@@ -145,7 +145,7 @@ struct CompiledSubgraph : public core::RefCounted {
   // owning main entry.
   CompiledSubgraph* main_entry = nullptr;
 
-  // Compiled Tpu program.
+  // Compiled TPU program group.
   std::unique_ptr<TpuProgramGroupInterface> tpu_program_group;
 
   // Computes total program size.
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc
index 4d1f306ec0c..73f55853306 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc
@@ -40,7 +40,7 @@ TpuCompilationCacheEntry::get_host_transfer_info() const {
 }
 
 const xla::HloProto* TpuCompilationCacheEntry::get_hlo_metadata() const {
-  return tpu_program_group_->hlo_metadatas()[core_index_].get();
+  return tpu_program_group_->hlo_metadatas()[core_index_];
 }
 
 // TODO(henrytan,jiawenhao): When should we expect more than one
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h
index a561fc51778..b3766b8b4dd 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h
@@ -23,7 +23,7 @@ limitations under the License.
 namespace tensorflow {
 namespace tpu {
 
-// A version of `CompilationCacheEntry` that exposes Tpu binary program
+// A version of `CompilationCacheEntry` that provides access to the TPU binary
 // `XLA_TpuProgram`.
 class TpuCompilationCacheEntry {
  public:
@@ -42,28 +42,6 @@ class TpuCompilationCacheEntry {
   int core_index_;
 };
 
-// Base class for a reference to a cached proto. A unique_ptr to a
-// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
-// and ensures the underlying proto is not garbage-collected until the client
-// discards the ptr.
-class CompilationCacheEntryRef {
- public:
-  virtual ~CompilationCacheEntryRef() = default;
-
-  // Returns a CompilationCacheEntry that should not be used beyond the lifetime
-  // of the CompilationCacheEntryRef.
-  virtual TpuCompilationCacheEntry get() = 0;
-};
-
-// Base class that holds references to compiled protos so that the protos are
-// not garbage-collected before being used by execute ops. Use
-// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
-// ref holder object.
-class CompilationRefHolder : public ResourceBase {
- public:
-  ~CompilationRefHolder() override = default;
-};
-
 }  // namespace tpu
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h
new file mode 100644
index 00000000000..501f802b01f
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h
@@ -0,0 +1,108 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
+
+#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Wrapper for a cache entry that holds a reference to the entry until the
+// wrapper is deleted. This wrapper is the concrete type of
+// CompilationCacheEntryRef returned by Lookup.
+template <typename CacheEntryType>
+class CompilationCacheEntryRefImpl
+    : public CompilationCacheEntryRef<CacheEntryType> {
+ public:
+  CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent,
+                               CompiledSubgraph* entry, int index);
+
+  ~CompilationCacheEntryRefImpl() override;
+
+  Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override;
+
+ protected:
+  TpuCompilationCacheInterface* parent_;  // Not owned.
+  // A reference to entry_ is acquired in the constructor and released via
+  // parent->DiscardEntryRefs in the destructor.
+  CompiledSubgraph* entry_;
+  // The index of the program in entry_ that is returned by the get method.
+  int index_;
+};
+
+template <typename CacheEntryType>
+CompilationCacheEntryRefImpl<CacheEntryType>::CompilationCacheEntryRefImpl(
+    TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
+    : parent_(parent), entry_(entry), index_(index) {
+  if (entry_ == nullptr) {
+    return;
+  }
+  if (entry_->main_entry == nullptr) {
+    entry_->Ref();
+  } else {
+    // This is a sharding/unsharding entry nested in a main entry. Only
+    // refcount the main entry.
+    entry_->main_entry->Ref();
+  }
+}
+
+template <typename CacheEntryType>
+CompilationCacheEntryRefImpl<CacheEntryType>::~CompilationCacheEntryRefImpl() {
+  if (entry_ == nullptr) {
+    return;
+  }
+  if (entry_->main_entry == nullptr) {
+    parent_->DiscardEntryRefs({entry_});
+  } else {
+    parent_->DiscardEntryRefs({entry_->main_entry});
+  }
+}
+
+template <typename CacheEntryType>
+Status CompilationCacheEntryRefImpl<CacheEntryType>::ToSubEntryRef(
+    CompilationCacheFetchTarget fetch_target) {
+  CompiledSubgraph* target = nullptr;
+  switch (fetch_target) {
+    case CompilationCacheFetchTarget::MAIN:
+      target = entry_;
+      break;
+    case CompilationCacheFetchTarget::SHARDING:
+      target = entry_->sharding_entry.get();
+      break;
+    case CompilationCacheFetchTarget::UNSHARDING:
+      target = entry_->unsharding_entry.get();
+      break;
+    default:
+      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
+  }
+
+  if (target == nullptr) {
+    // Cache entry does not have the requested sharding/unsharding subentry.
+    // Unref the entry and replace it with nullptr.
+    parent_->DiscardEntryRefs({entry_});
+  }
+  // Otherwise, since the refcount is always on the main entry, we don't
+  // need ref/unref.
+  entry_ = target;
+  return Status::OK();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
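The template above supplies the ref-counting constructor/destructor and ToSubEntryRef, so a concrete cache only has to provide get(); the real specialization for TpuCompilationCacheEntry appears in tpu_compilation_cache_external.{h,cc} below. A minimal sketch, using a hypothetical entry type MyCacheEntry that is not part of this change:

    // Hypothetical concrete ref: ref-counting and sub-entry redirection are
    // inherited from CompilationCacheEntryRefImpl; only get() is supplied.
    class MyEntryRef : public CompilationCacheEntryRefImpl<MyCacheEntry> {
     public:
      MyEntryRef(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry,
                 int index)
          : CompilationCacheEntryRefImpl<MyCacheEntry>(parent, entry, index) {}

      MyCacheEntry get() override {
        // A null entry_ stands for a missing sharding/unsharding sub-entry.
        if (entry_ == nullptr) return MyCacheEntry();
        return MyCacheEntry(entry_->tpu_program_group.get(), index_);
      }
    };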
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc
index 614dfbdf577..8cee90e8e55 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc
@@ -50,14 +50,6 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry,
   entry->initialized = true;
 }
 
-std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
-  if (!key.has_guaranteed_const) {
-    return key.prefix;
-  }
-  return absl::StrCat(key.prefix, "|", key.session_handle, "|",
-                      key.guaranteed_const_fingerprint());
-}
-
 // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
 // data to compute the fingerprint.
 std::string GuaranteedConstFingerprint(
@@ -123,85 +115,32 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
 
 }  // namespace
 
-TpuCompilationCacheExternal::TpuCompilationCacheExternal(int64_t max_cache_size)
-    : max_cache_size_(max_cache_size) {
-  if (max_cache_size < 0) {
-    LOG(FATAL) << "`max_cache_size` value must be greater than equal to 0";
-  }
-  VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
-}
+TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl(
+    TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
+    : CompilationCacheEntryRefImpl<TpuCompilationCacheEntry>(parent, entry,
+                                                             index) {}
 
-TpuCompilationCacheExternal::~TpuCompilationCacheExternal() {
-  VLOG(1) << "TpuCompilationCacheExternal::~TpuCompilationCacheExternal()";
-  // A buggy client may be holding onto a reference, or a client might have
-  // crashed while holding onto a reference. In either case, discard all
-  // outstanding client references to avoid leaking storage.
-  for (const auto& entry : entries_by_uid_) {
-    while (entry.second->external_references > 0) {
-      TF_CHECK_OK(Release(entry.first));
-    }
+TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() {
+  if (entry_ == nullptr) {
+    // Create an empty entry if the entry is nullptr. This corresponds to
+    // non-existing sharding/unsharding entries.
+    return TpuCompilationCacheEntry();
   }
-  while (!entries_by_last_use_.empty()) {
-    UnloadAndDestroy(MarkOldestEntryForEviction());
-  }
-  // By the time the cache is deleted all reference holders should have already
-  // been deleted, since they were holding references to the cache. So all
-  // entries should be gone at this point.
-  CHECK_EQ(cache_store_.size(), 0);
-  CHECK_EQ(entries_by_uid_.size(), 0);
-  CHECK_EQ(entries_by_proto_key_.size(), 0);
-  CHECK_EQ(cache_size_, 0);
-  CHECK_EQ(marked_for_eviction_size_, 0);
-}
-
-std::string TpuCompilationCacheExternal::FindCacheKey(
-    const TpuCompilationCacheKey& subgraph_key) const {
-  if (!subgraph_key.has_guaranteed_const) {
-    return subgraph_key.prefix;
-  }
-  auto iter = session_key_map_.find(
-      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
-  if (iter != session_key_map_.end()) {
-    return iter->second;
-  }
-  iter = fingerprint_key_map_.find(strings::StrCat(
-      subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
-  if (iter != session_key_map_.end()) {
-    return iter->second;
-  }
-  VLOG(1) << "No matching cache key found for key "
-          << ConstructCompilationCacheKey(subgraph_key);
-  return "";
-}
-
-void TpuCompilationCacheExternal::InsertEntry(
-    const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key,
-    CompiledSubgraph* entry) {
-  entry->parent = this;
-  entry->subgraph_key = cache_key;
-  entry->uid = get_uid();
-  TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
-  entry->cache_entry_debug_string = subgraph_key.prefix;
-  VLOG(1) << "Cache Initializing Entry Session Debug "
-          << entry->cache_entry_debug_string;
-
-  if (!subgraph_key.has_guaranteed_const) {
-    return;
-  }
-  session_key_map_.insert(std::make_pair(
-      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle),
-      cache_key));
-  fingerprint_key_map_.insert(std::make_pair(
-      strings::StrCat(subgraph_key.prefix,
-                      subgraph_key.guaranteed_const_fingerprint()),
-      cache_key));
+  return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
 }
 
 CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
     const string& key,
-    const std::function<Status(TpuProgramGroup*)>& initialize_program,
+    const std::function<Status(TpuProgramGroupInterface*)>& initialize_program,
     const TpuCompilationCacheKey& subgraph_key) {
   CompiledSubgraph* main_entry = new CompiledSubgraph();
+  main_entry->parent = this;
+  main_entry->subgraph_key = key;
+  main_entry->uid = get_uid();
+  // TODO(henrytan): implement TpuCompilationCacheKey.debug_string.
+  main_entry->cache_entry_debug_string = subgraph_key.prefix;
+  VLOG(1) << "Cache Initializing Entry Session Debug "
+          << main_entry->cache_entry_debug_string;
 
   // Add the entry to the cache, with size zero since there are no compiled
   // programs in it. Once the subgraph has been compiled,
@@ -212,7 +151,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
   // who created the entry. A second reference, owned by the cache, will be
   // added below since we leave the entry in the 'marked for eviction' state
   // here.
-  InsertEntry(key, subgraph_key, main_entry);
+  InsertEntry(key, main_entry);
 
   // Initialize the programs outside the lock so that other cache operations
   // can proceed during the (potentially lengthy) initialization.
@@ -320,470 +259,5 @@ TpuCompilationCacheExternal::CreateCompilationCacheKey(
   }
   return key;
 }
-
-TpuCompilationRefHolder* TpuCompilationCacheExternal::MakePerStepRefHolder() {
-  return new RefHolder(this);
-}
-
-Status TpuCompilationCacheExternal::MarkEntryForEviction(int64 subgraph_uid) {
-  profiler::TraceMe key_release_traceme(
-      "TPU compilation cache possibly evict uid",
-      /*level=*/2);
-  CompiledSubgraph* deleted_entry = nullptr;
-  {
-    absl::MutexLock lock(&mu_);
-    auto iter = entries_by_uid_.find(subgraph_uid);
-    if (iter == entries_by_uid_.end()) {
-      // If already evicted, return ok.
-      return Status::OK();
-    }
-
-    // Mark entry for eviction.
-    CompiledSubgraph* subgraph_to_evict = iter->second;
-    // If there are external references, should not use this API.
-    if (subgraph_to_evict->external_references != 0) {
-      return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
-                              " external_references greater than zero. Should "
-                              "use TpuCompilationCache::Release.");
-    }
-
-    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction";
-    entries_by_last_use_.erase(subgraph_to_evict->last_use);
-    cache_size_ -= subgraph_to_evict->total_size;
-    marked_for_eviction_size_ += subgraph_to_evict->total_size;
-
-    // Evict if refcount exactly one, otherwise only discard cache's reference
-    // to the entry while the actual eviction will happen when refholder's
-    // references go away.
-    deleted_entry = DiscardEntryRef(subgraph_to_evict);
-
-    VLOG(1) << "After possibly evicting entry " << subgraph_uid
-            << " refs cache is " << cache_store_.size() << " entries ("
-            << cache_size_ + marked_for_eviction_size_
-            << " bytes), marked for eviction "
-            << (cache_store_.size() - entries_by_last_use_.size())
-            << " entries (" << marked_for_eviction_size_ << " bytes).";
-  }
-
-  // Unload from device cache if entry is evicted from host cache.
-  UnloadAndDestroy(deleted_entry);
-  return Status::OK();
-}
-
-Status TpuCompilationCacheExternal::Release(int64 subgraph_uid) {
-  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
-                                        /*level=*/2);
-
-  CompiledSubgraph* deleted_entry = nullptr;
-  {
-    absl::MutexLock lock(&mu_);
-    auto iter = entries_by_uid_.find(subgraph_uid);
-
-    if (iter == entries_by_uid_.end()) {
-      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
-    }
-
-    CHECK_GT(iter->second->external_references, 0);
-    --iter->second->external_references;
-
-    deleted_entry = DiscardEntryRef(iter->second);
-
-    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
-            << cache_store_.size() << " entries ("
-            << cache_size_ + marked_for_eviction_size_
-            << " bytes), marked for eviction "
-            << (cache_store_.size() - entries_by_last_use_.size())
-            << " entries (" << marked_for_eviction_size_ << " bytes).";
-  }
-  UnloadAndDestroy(deleted_entry);
-  return Status::OK();
-}
-
-void TpuCompilationCacheExternal::UnloadAndDestroy(CompiledSubgraph* entry) {
-  if (!entry) return;
-
-  CHECK(entry->RefCountIsOne());
-  entry->tpu_program_group->UnloadAndDestroyPrograms();
-  entry->Unref();
-}
-
-size_t TpuCompilationCacheExternal::RemoveEntry(const string& key) {
-  auto erased = cache_store_.erase(key);
-  TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
-  auto parsed_key_or_status = ParseCompilationCacheKey(key);
-  CHECK(parsed_key_or_status.status().ok());
-  const TpuCompilationCacheKey parsed_key =
-      parsed_key_or_status.ConsumeValueOrDie();
-  if (!parsed_key.has_guaranteed_const) {
-    return erased;
-  }
-  session_key_map_.erase(
-      strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
-  fingerprint_key_map_.erase(strings::StrCat(
-      parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
-  return erased;
-}
-
-ABSL_MUST_USE_RESULT CompiledSubgraph*
-TpuCompilationCacheExternal::DiscardEntryRef(CompiledSubgraph* entry) {
-  if (entry->RefCountIsOne()) {
-    // The last reference to this entry is going away, so really delete it from
-    // the cache in such a way that it can't be restored by being looked up
-    // again.
-
-    // Sanity-check that it has been marked for eviction.
-    CHECK(entries_by_last_use_.find(entry->last_use) ==
-          entries_by_last_use_.end());
-    // Update the counter tracking how much space is taken up by entries that
-    // are marked for eviction.
-    marked_for_eviction_size_ -= entry->total_size;
-
-    // Remove the entry from the cache.
-    auto erased = RemoveEntry(entry->subgraph_key);
-
-    if (erased == 0) {
-      LOG(FATAL) << "Tried to discard nonexistent cache entry";
-    }
-    erased = entries_by_uid_.erase(entry->uid);
-    CHECK_EQ(erased, 1);
-    for (const string& key : entry->proto_key) {
-      erased = entries_by_proto_key_.erase(key);
-      CHECK_EQ(erased, 1);
-    }
-    // The actual deletion will happen outside the lock in UnloadAndDestroy().
-    return entry;
-  }
-  entry->Unref();
-  return nullptr;
-}
-
-void TpuCompilationCacheExternal::DiscardEntryRefs(
-    gtl::ArraySlice<CompiledSubgraph*> entries) {
-  std::vector<CompiledSubgraph*> removed_entries;
-  {
-    absl::MutexLock lock(&mu_);
-
-    for (auto entry : entries) {
-      removed_entries.push_back(DiscardEntryRef(entry));
-    }
-
-    VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
-            << " entries (" << cache_size_ + marked_for_eviction_size_
-            << " bytes), marked for eviction "
-            << (cache_store_.size() - entries_by_last_use_.size())
-            << " entries (" << marked_for_eviction_size_ << " bytes).";
-  }
-  for (auto removed_entry : removed_entries) {
-    UnloadAndDestroy(removed_entry);
-  }
-}
-
-ABSL_MUST_USE_RESULT CompiledSubgraph*
-TpuCompilationCacheExternal::MarkOldestEntryForEviction() {
-  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
-  VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
-  entries_by_last_use_.erase(entry_to_mark->last_use);
-  cache_size_ -= entry_to_mark->total_size;
-  marked_for_eviction_size_ += entry_to_mark->total_size;
-  // Discard the cache's reference to entry. If steps are holding onto
-  // references to entry it won't be deleted until the last step holding it
-  // completes. It stays in the cache in the meantime and can be resurrected
-  // by a call to CompileIfKeyAbsent if that occurs before the last reference
-  // expires.
-  return DiscardEntryRef(entry_to_mark);
-}
-
-void TpuCompilationCacheExternal::LookupEntryMarkedForEviction(
-    CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) {
-  // The entry was previously marked for eviction (or is newly created) so
-  // unmark it. Add a reference (owned by the cache), update the cache size, and
-  // mark something old for eviction if necessary.
-  entry->Ref();
-  marked_for_eviction_size_ -= entry->total_size;
-  cache_size_ += entry->total_size;
-
-  // Mark the least-recently-used non-marked entry for eviction. Never mark the
-  // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
-  // which means there's only one entry not already marked for eviction), so
-  // that an entry persists in the cache even if it is larger than the allocated
-  // cache size.
-  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
-    if (auto entry_to_evict = MarkOldestEntryForEviction()) {
-      removed_entries->push_back(entry_to_evict);
-    }
-  }
-}
-
-Status TpuCompilationCacheExternal::ToSubEntryRef(
-    CompilationCacheEntryRef* entry,
-    CompilationCacheFetchTarget fetch_target) const {
-  return static_cast<TpuEntryRefImpl*>(entry)->ToSubEntryRef(fetch_target);
-}
-
-TpuCompilationCacheExternal::TpuEntryRefImpl::TpuEntryRefImpl(
-    TpuCompilationCacheExternal* parent, CompiledSubgraph* entry, int index)
-    : parent_(parent), entry_(entry), index_(index) {
-  if (entry_ == nullptr) {
-    return;
-  }
-  if (entry_->main_entry == nullptr) {
-    entry_->Ref();
-  } else {
-    // This is a sharding/unsharding entry nested in a main entry. Only refcount
-    // the main entry.
-    entry_->main_entry->Ref();
-  }
-}
-
-TpuCompilationCacheExternal::TpuEntryRefImpl::~TpuEntryRefImpl() {
-  if (entry_ == nullptr) {
-    return;
-  }
-  if (entry_->main_entry == nullptr) {
-    parent_->DiscardEntryRefs({entry_});
-  } else {
-    parent_->DiscardEntryRefs({entry_->main_entry});
-  }
-}
-
-TpuCompilationCacheEntry TpuCompilationCacheExternal::TpuEntryRefImpl::get() {
-  if (entry_ == nullptr) {
-    // Create an empty entry if the entry is nullptr. This corresponds to
-    // non-existing sharding/unsharding entries.
-    return TpuCompilationCacheEntry();
-  }
-  return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
-}
-
-Status TpuCompilationCacheExternal::TpuEntryRefImpl::ToSubEntryRef(
-    CompilationCacheFetchTarget fetch_target) {
-  CompiledSubgraph* target = nullptr;
-  switch (fetch_target) {
-    case CompilationCacheFetchTarget::MAIN:
-      target = entry_;
-      break;
-    case CompilationCacheFetchTarget::SHARDING:
-      target = entry_->sharding_entry.get();
-      break;
-    case CompilationCacheFetchTarget::UNSHARDING:
-      target = entry_->unsharding_entry.get();
-      break;
-    default:
-      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
-  }
-
-  if (target == nullptr) {
-    // Cache entry does not have an unsharding subentry. Unref and replace
-    // with nullptr.
-    parent_->DiscardEntryRefs({entry_});
-  }
-  // Otherwise, since the refcount is always on the main entry, we don't need
-  // ref/unref.
-  entry_ = target;
-  return Status::OK();
-}
-
-Status TpuCompilationCacheExternal::Lookup(
-    int64 uid, int proto_index,
-    std::unique_ptr<CompilationCacheEntryRef>* entry) {
-  entry->reset();
-
-  profiler::TraceMe proto_lookup_traceme(
-      "TPU compilation cache proto lookup by uid",
-      /*level=*/2);
-
-  absl::MutexLock lock(&mu_);
-  const auto iter = entries_by_uid_.find(uid);
-  if (iter == entries_by_uid_.end()) {
-    return errors::NotFound("No subgraph found for uid ", uid);
-  }
-  CompiledSubgraph* cache_entry = iter->second;
-  if (proto_index < 0 ||
-      proto_index >= cache_entry->tpu_program_group->program_size()) {
-    return errors::NotFound("No proto found for core index ", proto_index,
-                            " in subgraph with uid ", uid);
-  }
-  *entry = std::unique_ptr<CompilationCacheEntryRef>(
-      new TpuEntryRefImpl(this, cache_entry, proto_index));
-  return Status::OK();
-}
-
-Status TpuCompilationCacheExternal::Lookup(
-    const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry) {
-  entry->reset();
-
-  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
-                                         /*level=*/2);
-
-  absl::MutexLock lock(&mu_);
-  const auto iter = entries_by_proto_key_.find(proto_key);
-  if (iter == entries_by_proto_key_.end()) {
-    return errors::NotFound("No proto found for key ", proto_key);
-  }
-  CompiledSubgraph* cache_entry = iter->second.first;
-  int proto_index = iter->second.second;
-  *entry = std::unique_ptr<CompilationCacheEntryRef>(
-      new TpuEntryRefImpl(this, cache_entry, proto_index));
-  return Status::OK();
-}
-
-Status TpuCompilationCacheExternal::CompileIfKeyAbsentHelper(
-    const TpuCompilationCacheKey& subgraph_key,
-    const SessionMetadata* session_metadata,
-    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
-    std::vector<CompiledSubgraph*>* removed_entries,
-    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
-    const std::function<Status(TpuProgramGroup*)>& compile_function) {
-  profiler::TraceMe subgraph_lookup_traceme(
-      "TPU compilation cache subgraph lookup",
-      /*level=*/2);
-
-  // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock
-  // for the lifetime of the object, see InitializeEntry() call below.
-  absl::MutexLock lock(&mu_);
-
-  std::string cache_key = FindCacheKey(subgraph_key);
-  auto iter = cache_store_.find(cache_key);
-  bool is_new_key = iter == cache_store_.end();
-
-  const std::string session_name = SessionNameFromMetadata(session_metadata);
-
-  CompiledSubgraph* entry = nullptr;
-  if (is_new_key) {
-    cache_key = ConstructCompilationCacheKey(subgraph_key);
-    TpuCompilationCacheMetrics::IncrementCacheLookupCount(
-        /*is_cache_hit=*/false, session_name);
-    const string msg =
-        strings::StrCat("TPU host compilation cache miss: cache_key(",
-                        cache_key, "), session_name(", session_name, ")");
-
-    TRACESTRING(msg);
-    LOG(INFO) << msg;
-
-    // Check if caller has disabled compilation. Set using
-    // internal::ScopedTpuCompileDisabler.
-    if (!IsTpuCompilationEnabled()) {
-      const string error_msg = strings::StrCat(
-          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
-          "disabled, session_name(",
-          session_name, ") Debug String: ", subgraph_key.debug_string);
-      if (VLOG_IS_ON(2)) {
-        VLOG(2) << "Cache Missed. Current cache entries: ";
-        for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) {
-          // TODO(henrytan): add DebugKey as cache_entry_debug_string to
-          // TpuCompilationCacheKey.
-          VLOG(2) << "Cache Debug Info: ";
-          VLOG(2) << it->second->cache_entry_debug_string;
-        }
-      }
-
-      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
-      return errors::NotFound(error_msg);
-    }
-
-    // The single ref on the newly-created entry is owned by the caller.
-    VLOG(1) << "Before adding new entry for key " << cache_key
-            << " with session_name( " << session_name << ");"
-            << "; cache is " << cache_store_.size() << " entries ("
-            << cache_size_ + marked_for_eviction_size_ << " bytes), "
-            << " marked for eviction "
-            << (cache_store_.size() - entries_by_last_use_.size())
-            << " entries (" << marked_for_eviction_size_ << " bytes).";
-    // Note that InitializeEntry() will Release/Reacquire mu_.
-    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
-    TRACELITERAL("TPU host compilation cache: compilation done.");
-
-    LOG(INFO) << strings::StrCat(
-        "TPU host compilation cache: compilation done for cache_key(",
-        cache_key, "), session_name(", session_name, ")");
-    // If session_name is present, log some additional stats related to HBM
-    // here, so that they can be associated directly to the session.
-    if (!session_name.empty()) {
-      entry->tpu_program_group->LogProgramMemorySummary();
-    }
-  } else {
-    TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name);
-    const string msg =
-        strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
-                        "), session_name(", session_name, ")");
-    TRACESTRING(msg);
-    VLOG(1) << msg;
-    VLOG(1) << "Before refreshing entry for key " << cache_key
-            << " with session_name( " << session_name << "); cache is "
-            << cache_store_.size() << " entries ("
-            << cache_size_ + marked_for_eviction_size_ << " bytes), "
-            << " marked for eviction "
-            << (cache_store_.size() - entries_by_last_use_.size())
-            << " entries (" << marked_for_eviction_size_ << " bytes).";
-    entry = iter->second;
-    // Make a new reference that is owned by the caller.
-    entry->Ref();
-    // Block if necessary until the subgraph has been initialized.
-    mu_.Await(absl::Condition(
-        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
-  }
-
-  // Let the caller know the uid of the entry.
-  *uid = entry->uid;
-  // Let the caller know the keys for each of the cached protos.
-  *proto_key = entry->proto_key;
-  *may_modify_variables = entry->tpu_program_group->may_modify_variables();
-  *hlo_metadata = entry->tpu_program_group->hlo_metadatas();
-
-  // If the caller didn't supply a per_step_ref_holder then the caller is going
-  // to manually release the reference later via a call to Release().
-  if (per_step_ref_holder == nullptr) {
-    ++entry->external_references;
-  } else {
-    // The caller wants its reference to be handed off to a per-step holder that
-    // will discard the reference when the step completes.
-    RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder);
-    TF_RET_CHECK(cast_ref_holder != nullptr);
-    cast_ref_holder->AddRef(entry);
-  }
-
-  // Remove the old LRU-table entry if it wasn't already marked for eviction.
-  auto erased = entries_by_last_use_.erase(entry->last_use);
-  // Update the LRU table indicating this entry is the most recently used.
-  entry->last_use = use_counter_++;
-  entries_by_last_use_[entry->last_use] = entry;
-  if (erased == 0) {
-    // The entry had been marked for eviction, or is newly created.
-    LookupEntryMarkedForEviction(entry, removed_entries);
-  }
-
-  // Log a little more verbosely when a key is added.
-  if (VLOG_IS_ON(1) || is_new_key) {
-    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
-              << " entry for key " << cache_key << " with session_name "
-              << session_name << " cache is " << cache_store_.size()
-              << " entries (" << cache_size_ + marked_for_eviction_size_
-              << " bytes), "
-              << " marked for eviction "
-              << (cache_store_.size() - entries_by_last_use_.size())
-              << " entries (" << marked_for_eviction_size_ << " bytes).";
-  }
-  return entry->initialization_status;
-}
-
-tensorflow::Status TpuCompilationCacheExternal::CompileIfKeyAbsent(
-    const TpuCompilationCacheKey& cache_key,
-    const tensorflow::SessionMetadata* session_metadata,
-    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
-    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
-    const std::function<tensorflow::Status(TpuProgramGroup*)>&
-        compile_function) {
-  std::vector<CompiledSubgraph*> removed_entries;
-  auto status = CompileIfKeyAbsentHelper(
-      cache_key, session_metadata, per_step_ref_holder, uid, proto_key,
-      may_modify_variables, &removed_entries, hlo_metadata, compile_function);
-  for (auto entry : removed_entries) {
-    UnloadAndDestroy(entry);
-  }
-  return status;
-}
-
 }  // namespace tpu
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
index eff2afde108..2c75cb4d053 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
@@ -26,11 +26,14 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
 #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
@@ -40,37 +43,25 @@ limitations under the License.
 namespace tensorflow {
 namespace tpu {
 
-const char kCompilationCacheResourceName[] = "tpu_compilation_cache";
-const char kCompilationCacheUnloaderResourceName[] =
+constexpr char kCompilationCacheResourceName[] = "tpu_compilation_cache";
+constexpr char kCompilationCacheUnloaderResourceName[] =
     "tpu_compilation_cache_unloader";
 
-// Base class that holds references to compiled protos so that the protos are
-// not garbage-collected before being used by execute ops. Use
-// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
-// ref holder object.
-class TpuCompilationRefHolder : public ResourceBase {
- public:
-  ~TpuCompilationRefHolder() override = default;
-};
-
-class TpuCompilationCacheExternal : public ResourceBase {
+class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
  public:
   using Status = ::stream_executor::port::Status;
 
-  explicit TpuCompilationCacheExternal(int64_t max_cache_size);
-  ~TpuCompilationCacheExternal() override;
-  TpuCompilationCacheExternal(const TpuCompilationCacheExternal&) = delete;
-  TpuCompilationCacheExternal& operator=(const TpuCompilationCacheExternal&) =
-      delete;
+  class EntryRefImpl
+      : public CompilationCacheEntryRefImpl<TpuCompilationCacheEntry> {
+   public:
+    EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry,
+                 int index);
 
-  Status CompileIfKeyAbsent(
-      const TpuCompilationCacheKey& cache_key,
-      const SessionMetadata* session_metadata,
-      TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
-      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
-      std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
-      const std::function<tensorflow::Status(TpuProgramGroup*)>&
-          compile_function);
+    TpuCompilationCacheEntry get() override;
+  };
+
+  explicit TpuCompilationCacheExternal(int64 max_cache_size)
+      : TpuCompilationCacheInterface(max_cache_size) {}
 
   static TpuCompilationCacheKey CreateCompilationCacheKey(
       absl::string_view function_name, uint64 function_library_fingerprint,
@@ -82,177 +73,7 @@ class TpuCompilationCacheExternal : public ResourceBase {
 
   string DebugString() const override { return "TpuCompilationCacheExternal"; }
 
-  // Makes a reference holder for this cache, that can be stored in the per-step
-  // resource manager and will ensure that compiled entries persist until the
-  // end of a step.
-  TpuCompilationRefHolder* MakePerStepRefHolder();
-
-  // Differences between MarkEntryForEviction and Release:
-  // There are two modes of managing cache entries:
-  // 1) LRU eviction + pinning; 2) manual.
-  // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
-  // Otherwise it is manual mode (mainly used by XRT).
-  // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
-  // entries when callers know that they do not need them anymore.
-  // Release should only be used in mode 2) to explicitly remove an entry.
-
-  // Mark the entry indexed by `subgraph_uid` for eviction. This should only be
-  // called if per_step_ref_holder was NOT nullptr in the corresponding call to
-  // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
-  // subgraph_uid).
-  Status MarkEntryForEviction(int64 subgraph_uid);
-
-  // Manually discards a reference to the compiled subgraph. This should only be
-  // called if per_step_ref_holder was nullptr in the corresponding call to
-  // CompileIfKeyAbsent(subgraph_key, ...).
-  Status Release(int64 subgraph_uid);
-
-  // Looks up an executable corresponding to the model-parallel core index of
-  // the subgraph represented by key. On success a pointer to an EntryRef
-  // holding the program is returned in entry.
-  Status Lookup(const string& proto_key,
-                std::unique_ptr<CompilationCacheEntryRef>* entry);
-
-  // Looks up an executable corresponding to the model-parallel core index of
-  // the subgraph represented by uid. On success a pointer to an EntryRef
-  // holding the program is returned in entry.
-  Status Lookup(int64 uid, int proto_index,
-                std::unique_ptr<CompilationCacheEntryRef>* entry);
-
-  // Mutates the main entry ref to point to the entry's subentry
-  // (for sharding/unsharding) or main entry (unchanged) representing the
-  // fetch target. The entry ref needs to point to the main entry before this
-  // call.
-  //
-  // If the requested subentry does not exist, the ref will point to a nullptr
-  // entry.
-  Status ToSubEntryRef(CompilationCacheEntryRef* entry,
-                       CompilationCacheFetchTarget fetch_target) const;
-
  private:
-  // Wrapper for a cache entry that holds a reference to the entry until the
-  // wrapper is deleted. This wrapper is the concrete type of
-  // CompilationCacheEntryRef returned by Lookup.
-  class TpuEntryRefImpl : public CompilationCacheEntryRef {
-   public:
-    TpuEntryRefImpl(TpuCompilationCacheExternal* parent,
-                    CompiledSubgraph* entry, int index);
-    ~TpuEntryRefImpl() override;
-
-    TpuCompilationCacheEntry get() override;
-
-    // Mutates this ref to point to the entry's subentry (for
-    // sharding/unsharding) or main entry (unchanged) as specified by
-    // fetch_target. The refcount is kept unchanged, since we only track the
-    // refcount of the main entry. The entry ref needs to point to the main
-    // entry before this call.
-    //
-    // If the requested subentry does not exist, the ref will point to a nullptr
-    // entry, and the original entry will be unref'ed.
-    Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
-
-   private:
-    TpuCompilationCacheExternal* parent_;  // Not owned.
-    // A reference to entry_ is acquired in the constructor and released via
-    // parent->DiscardEntryRefs in the destructor.
-    CompiledSubgraph* entry_;
-    // The program in entry_ that is returned by the get method.
-    int index_;
-  };
-
-  // Private implementation of the generic CompilationRefHolder that knows about
-  // CompiledSubgraph entries.
-  class RefHolder : public TpuCompilationRefHolder {
-   public:
-    explicit RefHolder(TpuCompilationCacheExternal* parent) : parent_(parent) {
-      parent_->Ref();
-    }
-    ~RefHolder() override {
-      // Release our reference to the parent.
-      parent_->Unref();
-    }
-
-    // Adds entry to the list of entries that will be released when the
-    // RefHolder is destroyed. Each entry is released via a call to
-    // parent_->DiscardEntryRefs.
-    void AddRef(CompiledSubgraph* entry) { entries_.push_back(entry); }
-
-    string DebugString() const override {
-      return "TpuCompilationCacheExternal::RefHolder";
-    }
-
-   private:
-    TpuCompilationCacheExternal* parent_;  // Not owned.
-    std::vector<CompiledSubgraph*> entries_;
-  };
-
-  // The bulk of implementation of CompileIfKeyAbsent() with the exception
-  // of unloading programs that corresponds to possibly removed cache
-  // entries. The split helps to manage locking since we prefer to perform
-  // unloading without holding extra locks.
-  Status CompileIfKeyAbsentHelper(
-      const TpuCompilationCacheKey& subgraph_key,
-      const SessionMetadata* session_metadata,
-      TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
-      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
-      std::vector<CompiledSubgraph*>* removed_entries,
-      std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
-      const std::function<Status(TpuProgramGroup*)>& compile_function);
-
-  // This is called by the cache when entry is marked for eviction; by
-  // a RefHolder (via DiscardEntryRefs) when a step completes; and by
-  // an EntryRefImpl when it is destroyed. Releases one reference to entry
-  // if more than 1 remains. If only one reference is left, the entry is removed
-  // from cache_ and is returned to the caller; which must eventually call
-  // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
-  // to avoid holding the lock during program unloading.
-  ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
-      CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-  // Convenience method called by ~RefHolder without mu_ held. Calls
-  // DiscardEntryRef on every element of entries.
-  void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
-
-  // Marks the oldest unmarked entry for eviction. Requires that there is at
-  // least one such entry. In case the evicted entry had only 1 reference it
-  // is removed from the cache and returned to the caller which must eventually
-  // call UnloadAndDestroy.
-  CompiledSubgraph* MarkOldestEntryForEviction()
-      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Updates datastructures to indicate that entry, which had been marked for
-  // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
-  // entry is newly created, or an entry that has been marked for eviction but
-  // not yet evicted is looked up.
-  //
-  // First the entry is unmarked for eviction, i.e. the cache gains a reference
-  // to entry, entry's last_use field is set to be the most recent value of
-  // use_counter_ and entries_by_last_use_ is updated accordingly.
-  //
-  // Next, the size of the cache is examined to see if any other entries need to
-  // be marked for eviction now that entry has been unmarked. While the total
-  // size of unmarked cached entries is greater than max_cache_size_, entries
-  // are marked for eviction in LRU order. The most recently used entry is never
-  // marked for eviction, so an entry larger than the max cache size will remain
-  // in the cache until it is replaced by something else. In case some entries
-  // actually were removed from the cache, they are a returned to the caller via
-  // removed_entries. The caller must eventually delete them by calling
-  // UnloadAndDestroy.
-  void LookupEntryMarkedForEviction(
-      CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
-      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Removes the entry with given key from cache.
-  size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Inserts the given key and entry to cache.
-  void InsertEntry(const std::string& key,
-                   const TpuCompilationCacheKey& subgraph_key,
-                   CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Returns the cache key matching given subgraph_key.
-  std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const
-      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
   // Creates a new entry by running initialize_programs and places it in the
   // cache to be looked up by key. The new entry is in the 'marked for eviction'
   // state (not present in entries_by_last_use_) and the caller is expected to
@@ -261,61 +82,10 @@ class TpuCompilationCacheExternal : public ResourceBase {
   // **InitializeEntry releases mu_ during the call to initialize_programs.**
   CompiledSubgraph* InitializeEntry(
       const string& key,
-      const std::function<Status(TpuProgramGroup*)>& initialize_program,
+      const std::function<Status(TpuProgramGroupInterface*)>&
+          initialize_program,
       const TpuCompilationCacheKey& subgraph_key)
-      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Unloads the program associated with the entry from all local devices
-  // and deletes the entry itself. It is assumed no one else has a reference
-  // to it and all related keys had already been removed from the cache.
-  // The call can perform device IO so no locks should be held while calling it.
-  void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
-
-  // The maximum size of entries that are stored in the cache before entries are
-  // marked for eviction.
-  const int64 max_cache_size_;
-
-  mutable absl::Mutex mu_;
-  // The total size of entries that are stored and not marked for eviction.
-  int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
-
-  // The total size of entries that are marked for eviction.
-  int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
-
-  // The value to assign to the last_use field of the next entry that is looked
-  // up.
-  int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
-
-  // session_key_map_ and fingerprint_key_map_ are used for looking up the
-  // cache_ key matching a given subgraph key. When doing a lookup, check
-  // session_key_map_ first to avoid unnecessay fingerprint computation.
-  // Map from key prefix + session_handle to a cache_ key.
-  std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
-
-  // Map from key prefix + fingerprint to a cache_ key.
-  std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
-
-  // All the subgraph entries that can be looked up in the cache. An entry is
-  // marked for eviction iff it is present in cache_ and not in
-  // entries_by_last_use_.
-  std::unordered_map<string, CompiledSubgraph*> cache_store_
-      ABSL_GUARDED_BY(mu_);
-
-  // All the subgraph entries that can be looked up in the cache, indexed by
-  // uid.
-  absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
-      ABSL_GUARDED_BY(mu_);
-
-  // All the protos that can be looked up in the cache, indexed by proto
-  // key. The value of the map is a subgraph and the index of the proto compiled
-  // for that subgraph.
-  std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
-      entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
-
-  // Map from last_use to entry, used to mark entries for eviction in LRU
-  // order. If an entry's last_use counter is not present as a key in
-  // entries_by_last_use_ then the entry has been marked for eviction.
-  std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(TpuCompilationCacheInterface::mu_) override;
 };
 
 }  // namespace tpu
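With TpuCompilationCacheExternal still a ResourceBase (now via TpuCompilationCacheInterface), creation and lookup through the resource manager are unchanged. A minimal sketch, assuming a ResourceMgr* rm is at hand and using an illustrative 1 GB size limit:

    tpu::TpuCompilationCacheExternal* cache = nullptr;
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<tpu::TpuCompilationCacheExternal>(
        rm->default_container(), tpu::kCompilationCacheResourceName, &cache,
        [](tpu::TpuCompilationCacheExternal** c) {
          // The size limit here is illustrative, not taken from this change.
          *c = new tpu::TpuCompilationCacheExternal(int64{1} << 30);
          return Status::OK();
        }));
    core::ScopedUnref cache_unref(cache);  // Drop the ref taken by LookupOrCreate.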
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
index f3e40df24dd..3b46f0f2d32 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
@@ -93,7 +93,9 @@ Status TpuCompilationCacheInterface::MarkEntryForEviction(int64 subgraph_uid) {
                               "use TpuCompilationCacheInterface::Release.");
     }
 
-    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction";
+    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key
+            << " for eviction. Debug string: "
+            << subgraph_to_evict->cache_entry_debug_string;
     entries_by_last_use_.erase(subgraph_to_evict->last_use);
     cache_size_ -= subgraph_to_evict->total_size;
     marked_for_eviction_size_ += subgraph_to_evict->total_size;
@@ -231,7 +233,9 @@ void TpuCompilationCacheInterface::DiscardEntryRefs(
 
 CompiledSubgraph* TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
   CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
-  VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
+  VLOG(1) << "Marking " << entry_to_mark->subgraph_key
+          << " for eviction. Debug string: "
+          << entry_to_mark->cache_entry_debug_string;
   entries_by_last_use_.erase(entry_to_mark->last_use);
   cache_size_ -= entry_to_mark->total_size;
   marked_for_eviction_size_ += entry_to_mark->total_size;
@@ -291,7 +295,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
     std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
-    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas,
+    absl::Span<const xla::HloProto* const>* hlo_metadatas,
     const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
   std::vector<CompiledSubgraph*> removed_entries;
   auto status = CompileIfKeyAbsentHelper(
@@ -328,7 +332,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
     CompilationRefHolder* per_step_ref_holder, int64* uid,
     std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
     std::vector<CompiledSubgraph*>* removed_entries,
-    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas,
+    absl::Span<const xla::HloProto* const>* hlo_metadatas,
     const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
   CompiledSubgraph* entry = nullptr;
 
@@ -388,7 +392,8 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
     TRACELITERAL("TPU host compilation cache: compilation done.");
     LOG(INFO) << strings::StrCat(
         "TPU host compilation cache: compilation done for cache_key(",
-        cache_key, "), session_name(", session_name, ")");
+        cache_key, "), session_name(", session_name, "), subgraph_key(",
+        subgraph_key.debug_string, ")");
     // If session_name is present, log some additional stats related to HBM
     // here, so that they can be associated directly to the session.
     if (!session_name.empty()) {
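Since hlo_metadatas is now surfaced as an absl::Span of raw pointers owned by the cached program group, callers read (or copy) the protos while their cache reference is still alive. A small sketch, assuming hlo_metadatas was filled in by the CompileIfKeyAbsent call:

    // The HloProto pointers are owned by the cached TpuProgramGroup; keep the
    // corresponding cache reference (uid or per-step holder) alive while
    // reading them.
    for (const xla::HloProto* hlo : hlo_metadatas) {
      if (hlo != nullptr) {
        VLOG(2) << "Compiled HLO module: " << hlo->hlo_module().name();
      }
    }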
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
new file mode 100644
index 00000000000..f92893b78f6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
@@ -0,0 +1,355 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
+#include "tensorflow/core/tpu/kernels/trace_util.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Base class that holds references to compiled protos so that the protos are
+// not garbage-collected before being used by execute ops. Use
+// TpuCompilationCacheInterface::MakePerStepRefHolder to create an instance
+// of a concrete ref holder object.
+class CompilationRefHolder : public ResourceBase {
+ public:
+  ~CompilationRefHolder() override = default;
+};
+
+// Base class for a reference to a cached tpu program. A unique_ptr to a
+// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
+// and ensures the underlying proto is not garbage-collected until the client
+// discards the ptr.
+template <typename CacheEntryType>
+class CompilationCacheEntryRef {
+ public:
+  virtual ~CompilationCacheEntryRef() = default;
+
+  // Returns a CompilationCacheEntry that should not be used beyond the
+  // lifetime of the tpu::CompilationCacheEntryRef.
+  virtual CacheEntryType get() = 0;
+
+  // Mutates this ref to point to the entry's subentry (for
+  // sharding/unsharding) or main entry (unchanged) as specified by
+  // fetch_target. The refcount is kept unchanged, since we only track the
+  // refcount of the main entry. The entry ref needs to point to the main
+  // entry before this call.
+  //
+  // If the requested subentry does not exist, the ref will point to a nullptr
+  // entry, and the original entry will be unref'ed.
+  virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0;
+};
+
+class TpuCompilationCacheInterface : public ResourceBase {
+ public:
+  explicit TpuCompilationCacheInterface(int64 max_cache_size);
+  ~TpuCompilationCacheInterface() override;
+
+  // Ensures there is an entry for key present in the cache. By the time
+  // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
+  // for key, and that entry will remain valid at least until
+  // per_step_ref_holder is deleted. The first call to CompileIfKeyAbsent with a
+  // key that is not in the cache will evaluate compile_function to compute the
+  // value to use in the entry. Subsequent calls with the same key will block
+  // until compile_function completes. Other cache reads and inserts may proceed
+  // on other threads while compile_function is executing. If
+  // per_step_ref_holder is nullptr then the caller is responsible for calling
+  // Release(subgraph_uid) to manually discard its reference to the compiled
+  // program once it no longer needs to look up the compiled program.
+  //
+  // compile_function should compile the subgraph represented by key and fill in
+  // one TPUExecutableProto per model-parallel core into its passed argument. It
+  // should return OK if and only if compilation succeeds. The executable proto
+  // vector will be discarded on non-OK status.
+  Status CompileIfKeyAbsent(
+      const TpuCompilationCacheKey& subgraph_key,
+      const SessionMetadata* session_metadata,
+      CompilationRefHolder* per_step_ref_holder, int64* uid,
+      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+      absl::Span<const xla::HloProto* const>* hlo_metadatas,
+      const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
+
+  // Differences between MarkEntryForEviction and Release:
+  // There are two modes of managing cache entries:
+  // 1) LRU eviction + pinning; 2) manual.
+  // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
+  // Otherwise it is manual mode (mainly used by XRT).
+  // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
+  // entries when callers know that they do not need them anymore.
+  // Release should only be used in mode 2) to explicitly remove an entry.
+
+  // Mark the entry indexed by `subgraph_uid` for eviction. This should only be
+  // called if per_step_ref_holder was NOT nullptr in the corresponding call to
+  // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
+  // subgraph_uid).
+  Status MarkEntryForEviction(int64 subgraph_uid);
+
+  // Manually discards a reference to the compiled subgraph. This should only be
+  // called if per_step_ref_holder was nullptr in the corresponding call to
+  // CompileIfKeyAbsent(subgraph_key, ...).
+  Status Release(int64 subgraph_uid);
+
+  // Looks up an executable corresponding to the model-parallel core index of
+  // the subgraph represented by key. On success a pointer to an EntryRef
+  // holding the program is returned in entry.
+  template <typename CacheEntryRef, typename CacheEntryRefImpl>
+  Status Lookup(const string& proto_key, std::unique_ptr<CacheEntryRef>* entry);
+
+  // Looks up an executable corresponding to the model-parallel core index of
+  // the subgraph represented by uid. On success a pointer to an EntryRef
+  // holding the program is returned in entry.
+  template <typename CacheEntryRef, typename CacheEntryRefImpl>
+  Status Lookup(int64 uid, int proto_index,
+                std::unique_ptr<CacheEntryRef>* entry);
+
+  // Looks up the subgraph represented by uid, and returns the vector of keys,
+  // one per core, corresponding to that subgraph.
+  Status GetKeysFromUid(int64 uid, std::vector<string>* keys);
+
+  // Makes a reference holder for this cache that can be stored in the
+  // per-step resource manager and will ensure that compiled entries persist
+  // until the end of a step.
+  CompilationRefHolder* MakePerStepRefHolder();
+
+  // Convenience method called by ~RefHolder without mu_ held. Calls
+  // DiscardEntryRef on every element of entries.
+  void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
+
+  string DebugString() const override { return "TpuCompilationCacheInterface"; }
+
+ protected:
+  std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
+    if (!key.has_guaranteed_const) {
+      return key.prefix;
+    }
+    return absl::StrCat(key.prefix, "|", key.session_handle, "|",
+                        key.guaranteed_const_fingerprint());
+  }
+
+  // Concrete implementation of the generic CompilationRefHolder that knows
+  // about CompiledSubgraph entries.
+  class RefHolder : public CompilationRefHolder {
+   public:
+    explicit RefHolder(TpuCompilationCacheInterface* parent);
+    ~RefHolder() override;
+
+    // Adds entry to the list of entries that will be released when the
+    // RefHolder is destroyed. Each entry is released via a call to
+    // parent_->DiscardEntryRefs.
+    void AddRef(CompiledSubgraph* entry);
+
+    string DebugString() const override;
+
+   private:
+    TpuCompilationCacheInterface* parent_;  // Not owned.
+    std::vector<CompiledSubgraph*> entries_;
+  };
+
+  // The bulk of the implementation of CompileIfKeyAbsent(), with the exception
+  // of unloading programs that correspond to possibly removed cache entries.
+  // The split helps to manage locking, since we prefer to perform unloading
+  // without holding extra locks.
+  Status CompileIfKeyAbsentHelper(
+      const TpuCompilationCacheKey& subgraph_key,
+      const SessionMetadata* session_metadata,
+      CompilationRefHolder* per_step_ref_holder, int64* uid,
+      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+      std::vector<CompiledSubgraph*>* removed_entries,
+      absl::Span<const xla::HloProto* const>* hlo_metadatas,
+      const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
+
+  // This is called by the cache when an entry is marked for eviction; by a
+  // RefHolder (via DiscardEntryRefs) when a step completes; and by an
+  // EntryRefImpl when it is destroyed. Releases one reference to the entry if
+  // more than one remains. If only one reference is left, the entry is removed
+  // from cache_ and returned to the caller, which must eventually call
+  // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
+  // to avoid holding the lock during program unloading.
+  ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
+      CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Marks the oldest unmarked entry for eviction. Requires that there is at
+  // least one such entry. If the evicted entry had only one reference, it is
+  // removed from the cache and returned to the caller, which must eventually
+  // call UnloadAndDestroy.
+  ABSL_MUST_USE_RESULT CompiledSubgraph* MarkOldestEntryForEviction()
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Updates data structures to indicate that entry, which had been marked for
+  // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
+  // entry is newly created, or when an entry that has been marked for eviction
+  // but not yet evicted is looked up.
+  //
+  // First the entry is unmarked for eviction, i.e. the cache gains a reference
+  // to entry, entry's last_use field is set to be the most recent value of
+  // use_counter_ and entries_by_last_use_ is updated accordingly.
+  //
+  // Next, the size of the cache is examined to see if any other entries need
+  // to be marked for eviction now that entry has been unmarked. While the
+  // total size of unmarked cached entries is greater than max_cache_size_,
+  // entries are marked for eviction in LRU order. The most recently used entry
+  // is never marked for eviction, so an entry larger than the max cache size
+  // will remain in the cache until it is replaced by something else. If any
+  // entries were actually removed from the cache, they are returned to the
+  // caller via removed_entries. The caller must eventually delete them by
+  // calling UnloadAndDestroy.
+  void LookupEntryMarkedForEviction(
+      CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Removes the entry with the given key from the cache.
+  size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Inserts the given key and entry into the cache.
+  void InsertEntry(const string& key, CompiledSubgraph* entry)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Returns the cache_ key matching the given subgraph_key.
+  string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Creates a new entry by running initialize_programs and places it in the
+  // cache to be looked up by key. The new entry is in the 'marked for eviction'
+  // state (not present in entries_by_last_use_) and the caller is expected to
+  // call LookupEntryMarkedForEviction after InitializeEntry.
+  //
+  // **InitializeEntry releases mu_ during the call to initialize_programs.**
+  virtual CompiledSubgraph* InitializeEntry(
+      const string& key,
+      const std::function<Status(TpuProgramGroupInterface*)>&
+          initialize_programs,
+      const TpuCompilationCacheKey& subgraph_key)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
+
+  // Unloads the program associated with the entry from all local devices and
+  // deletes the entry itself. It is assumed that no one else holds a reference
+  // to it and that all related keys have already been removed from the cache.
+  // The call can perform device IO, so no locks should be held when calling it.
+  void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
+
+  // The maximum size of entries that are stored in the cache before entries are
+  // marked for eviction.
+  const int64 max_cache_size_;
+  // Mutex to protect access to shared resources in a multi-threaded
+  // environment.
+  absl::Mutex mu_;
+  // The total size of entries that are stored and not marked for eviction.
+  int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
+  // The total size of entries that are marked for eviction.
+  int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
+  // The value to assign to the last_use field of the next entry that is looked
+  // up.
+  int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
+  // session_key_map_ and fingerprint_key_map_ are used for looking up the
+  // cache_ key matching a given subgraph key. When doing a lookup, check
+  // session_key_map_ first to avoid unnecessary fingerprint computation.
+  // Map from key prefix + session_handle to a cache_ key.
+  absl::node_hash_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
+  // Map from key prefix + fingerprint to a cache_ key.
+  absl::node_hash_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
+  // All the subgraph entries that can be looked up in the cache. An entry is
+  // marked for eviction iff it is present in cache_ and not in
+  // entries_by_last_use_.
+  std::unordered_map<string, CompiledSubgraph*> cache_ ABSL_GUARDED_BY(mu_);
+  // All the subgraph entries that can be looked up in the cache, indexed by
+  // uid.
+  absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
+      ABSL_GUARDED_BY(mu_);
+  // All the protos that can be looked up in the cache, indexed by proto
+  // key. The value of the map is a subgraph and the index of the proto compiled
+  // for that subgraph.
+  std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
+      entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
+  // Map from last_use to entry, used to mark entries for eviction in LRU
+  // order. If an entry's last_use counter is not present as a key in
+  // entries_by_last_use_ then the entry has been marked for eviction.
+  std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
+
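+  // Metrics for the TPU compilation cache.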
+  TpuCompilationCacheMetrics tpu_compilation_cache_metrics_;
+
+ private:
+  TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
+  TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
+      delete;
+};
+
+template <typename CacheEntryRef, typename CacheEntryRefImpl>
+Status TpuCompilationCacheInterface::Lookup(
+    int64 uid, int proto_index, std::unique_ptr<CacheEntryRef>* entry) {
+  entry->reset();
+
+  profiler::TraceMe proto_lookup_traceme(
+      "TPU compilation cache proto lookup by uid",
+      /*level=*/2);
+
+  absl::MutexLock lock(&mu_);
+  const auto iter = entries_by_uid_.find(uid);
+  if (iter == entries_by_uid_.end()) {
+    return errors::NotFound("No subgraph found for uid ", uid);
+  }
+  CompiledSubgraph* cache_entry = iter->second;
+  if (proto_index < 0 ||
+      proto_index >= cache_entry->tpu_program_group->program_count()) {
+    return errors::NotFound("No proto found for core index ", proto_index,
+                            " in subgraph with uid ", uid);
+  }
+  *entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
+  return Status::OK();
+}
+
+template <typename CacheEntryRef, typename CacheEntryRefImpl>
+Status TpuCompilationCacheInterface::Lookup(
+    const string& proto_key, std::unique_ptr<CacheEntryRef>* entry) {
+  entry->reset();
+
+  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
+                                         /*level=*/2);
+
+  absl::MutexLock lock(&mu_);
+  const auto iter = entries_by_proto_key_.find(proto_key);
+  if (iter == entries_by_proto_key_.end()) {
+    return errors::NotFound("No proto found for key ", proto_key);
+  }
+  CompiledSubgraph* cache_entry = iter->second.first;
+  int proto_index = iter->second.second;
+  *entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
+  return Status::OK();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
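
The interface supports the two entry-management modes documented above; the sketch below illustrates the manual mode, where the uid returned by CompileIfKeyAbsent pins a reference that must later be dropped via Release, while each templated Lookup hands back an entry ref whose destruction releases its own reference. SomeEntryRef/SomeEntryRefImpl are illustrative placeholders for a concrete CompilationCacheEntryRef specialization and its implementation type (such as the pair TpuCompilationCacheExternal provides), and namespace tensorflow::tpu is assumed to be in scope.

// Sketch only: `uid` came from a CompileIfKeyAbsent call made with
// per_step_ref_holder == nullptr (manual mode).
Status UseAndReleaseProgram(TpuCompilationCacheInterface* cache, int64 uid) {
  std::unique_ptr<SomeEntryRef> entry;
  Status lookup_status = cache->Lookup<SomeEntryRef, SomeEntryRefImpl>(
      uid, /*proto_index=*/0, &entry);
  TF_RETURN_IF_ERROR(lookup_status);
  // Use entry->get() while `entry` is alive; destroying it drops the
  // reference taken by this lookup.
  entry.reset();
  // Manual mode: drop the reference that CompileIfKeyAbsent pinned for us.
  return cache->Release(uid);
}
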
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
index 8b2e832a69e..9285dff62ce 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
@@ -42,7 +42,7 @@ std::string GetName(CompilationCacheFetchTarget target) {
 }  // namespace
 
 TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
-    TpuCompilationCacheExternal* cache)
+    TpuCompilationCacheInterface* cache)
     : cache_(cache) {}
 
 TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
@@ -50,17 +50,19 @@ TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
 }
 
 Status TpuCompilationCacheLocalLookup::Lookup(
-    const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
+    const string& proto_key,
+    std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
     CompilationCacheFetchTarget fetch_target) {
   profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
                                          /*level=*/2);
-  Status s = cache_->Lookup(proto_key, entry);
+  Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
+      proto_key, entry);
   VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
           << s;
   if (!s.ok()) {
     return s;
   }
-  s = cache_->ToSubEntryRef(entry->get(), fetch_target);
+  s = (*entry)->ToSubEntryRef(fetch_target);
 
   VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
           << s;
@@ -69,17 +71,18 @@ Status TpuCompilationCacheLocalLookup::Lookup(
 
 Status TpuCompilationCacheLocalLookup::Lookup(
     int64 uid, int proto_index,
-    std::unique_ptr<CompilationCacheEntryRef>* entry,
+    std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
     CompilationCacheFetchTarget fetch_target) {
   profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
                                          /*level=*/2);
-  Status s = cache_->Lookup(uid, proto_index, entry);
+  Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
+      uid, proto_index, entry);
   VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
           << " in local subgraph cache status " << s;
   if (!s.ok()) {
     return s;
   }
-  s = cache_->ToSubEntryRef(entry->get(), fetch_target);
+  s = (*entry)->ToSubEntryRef(fetch_target);
   VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
           << s;
   return s;
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
index 0d068e1bdd1..21ca74c46a8 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
@@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
-#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
 
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 
 namespace tensorflow {
 namespace tpu {
@@ -28,6 +30,11 @@ namespace tpu {
 // and when they need to communicate over RPC.
 class TpuCompilationCacheLookup : public ResourceBase {
  public:
+  using TpuCompilationCacheEntryRef =
+      ::tensorflow::tpu::CompilationCacheEntryRef<TpuCompilationCacheEntry>;
+  using EntryRefImpl =
+      ::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl;
+
   ~TpuCompilationCacheLookup() override = default;
 
   // Looks up an executable corresponding to the model-parallel core index of
@@ -42,11 +49,11 @@ class TpuCompilationCacheLookup : public ResourceBase {
   // fetch_target requests one of them, then after this call
   //   (*entry)->get().get_executable() will return nullptr.
   virtual Status Lookup(const string& proto_key,
-                        std::unique_ptr<CompilationCacheEntryRef>* entry,
+                        std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
                         CompilationCacheFetchTarget fetch_target) = 0;
 
   virtual Status Lookup(const string& proto_key,
-                        std::unique_ptr<CompilationCacheEntryRef>* entry) {
+                        std::unique_ptr<TpuCompilationCacheEntryRef>* entry) {
     return Lookup(proto_key, std::move(entry),
                   CompilationCacheFetchTarget::MAIN);
   }
@@ -56,33 +63,30 @@ class TpuCompilationCacheLookup : public ResourceBase {
   // returned in program. The wrapper is guaranteed to be valid only during the
   // execution of the Op requesting the proto.
   virtual Status Lookup(int64 uid, int proto_index,
-                        std::unique_ptr<CompilationCacheEntryRef>* entry,
+                        std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
                         CompilationCacheFetchTarget fetch_target) = 0;
 
   virtual Status Lookup(int64 uid, int proto_index,
-                        std::unique_ptr<CompilationCacheEntryRef>* entry) {
+                        std::unique_ptr<TpuCompilationCacheEntryRef>* entry) {
     return Lookup(uid, proto_index, std::move(entry),
                   CompilationCacheFetchTarget::MAIN);
   }
 };
 
-// Forward declaration to break cycle dependency graph.
-class TpuCompilationCacheExternal;
-
 // Class for looking up ISA protos when the execute and compile Op are in the
 // same address space. The proto is simply looked up in the compilation cache,
 // without any serialization taking place.
 class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
  public:
-  explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheExternal* cache);
+  explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
   ~TpuCompilationCacheLocalLookup() override;
 
   Status Lookup(const string& proto_key,
-                std::unique_ptr<CompilationCacheEntryRef>* entry,
+                std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
                 CompilationCacheFetchTarget fetch_target) override;
 
   Status Lookup(int64 uid, int proto_index,
-                std::unique_ptr<CompilationCacheEntryRef>* entry,
+                std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
                 CompilationCacheFetchTarget fetch_target) override;
 
   string DebugString() const override;
@@ -90,10 +94,10 @@ class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
  private:
   // The subgraph compilation cache, in the same process address space where the
   // lookups are happening.
-  TpuCompilationCacheExternal* cache_;
+  TpuCompilationCacheInterface* cache_;
 };
 
 }  // namespace tpu
 }  // namespace tensorflow
 
-#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
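
With the lookup layer now programmed against TpuCompilationCacheInterface and subentry resolution moved onto the entry ref itself, a local lookup no longer touches the concrete external cache type. A minimal sketch of a lookup by proto key, assuming `proto_lookup` was created elsewhere as a TpuCompilationCacheLocalLookup wrapping the process-local cache, `proto_key` was produced by CompileIfKeyAbsent, and namespace tensorflow::tpu is in scope:

// Sketch only; the helper name is illustrative.
Status FetchMainEntry(TpuCompilationCacheLookup* proto_lookup,
                      const string& proto_key) {
  std::unique_ptr<TpuCompilationCacheLookup::TpuCompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(proto_key, &entry,
                                          CompilationCacheFetchTarget::MAIN));
  // Whatever entry->get() exposes is valid only while `entry` is held;
  // requesting a subentry that does not exist leaves a null executable.
  return Status::OK();
}
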
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
index c8faba1d975..7ab1c9b8027 100644
--- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_util.h"
 #include "tensorflow/core/tpu/tpu_configuration.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h
index 0f21e458828..36f9fa96db1 100644
--- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h
@@ -24,7 +24,6 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/xla/client/compile_only_client.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc
index 43452b912ec..ecda2ef062e 100644
--- a/tensorflow/core/tpu/kernels/tpu_program_group.cc
+++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc
@@ -209,15 +209,8 @@ xla::HloProto TpuProgramGroup::hlo_metadata(int core_index) const {
       serialized_hlo_proto);
 }
 
-std::vector<std::shared_ptr<const xla::HloProto>>
-TpuProgramGroup::hlo_metadatas() const {
-  const size_t metadata_count = program_count();
-  std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadatas;
-  hlo_metadatas.resize(metadata_count);
-  for (size_t i = 0; i < metadata_count; ++i) {
-    hlo_metadatas[i] = std::make_shared<const xla::HloProto>(hlo_metadata(i));
-  }
-  return hlo_metadatas;
+absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
+  return absl::MakeConstSpan(hlo_metadatas_);
 }
 
 }  // namespace tpu
diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h
index de8256a9e59..0ade58e6daa 100644
--- a/tensorflow/core/tpu/kernels/tpu_program_group.h
+++ b/tensorflow/core/tpu/kernels/tpu_program_group.h
@@ -139,11 +139,15 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
   const xla::HloProto& hlo_metadata() const { return hlo_metadata_; }
   void set_hlo_metadata(const xla::HloProto& hlo_metadata) {
     hlo_metadata_ = hlo_metadata;
+
+    // TODO(henrytan): initialize hlo_metadatas_ for multi program support.
+    if (hlo_metadatas_.empty()) {
+      hlo_metadatas_.push_back(&hlo_metadata_);
+    }
   }
 
   xla::HloProto hlo_metadata(int core_index) const;
-  std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadatas()
-      const override;
+  absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
 
  private:
   std::vector<bool> may_modify_variables_;
@@ -153,6 +157,7 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
   TPUExecutableInfoProto executable_info_;
   TPUHostTransferInfoProto host_transfer_info_;
   xla::HloProto hlo_metadata_;
+  std::vector<const xla::HloProto*> hlo_metadatas_;
 };
 
 }  // namespace tpu
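
On the program-group side, set_hlo_metadata now also seeds the hlo_metadatas_ pointer vector (single program for now, per the TODO), and hlo_metadatas() returns a span that aliases the group's own storage. A short sketch of the intended lifetime contract, assuming namespace tensorflow::tpu is in scope; the helper name is illustrative.

// Sketch only: illustrates the aliasing contract of the new accessor.
void SetAndLogHloMetadata(TpuProgramGroup* program_group,
                          const xla::HloProto& hlo) {
  // Besides copying the proto, this now also seeds the internal
  // hlo_metadatas_ pointer vector (single-program for the time being).
  program_group->set_hlo_metadata(hlo);
  // The span aliases storage owned by *program_group; do not retain the
  // pointers after the group (or the cache entry that owns it) goes away.
  for (const xla::HloProto* metadata : program_group->hlo_metadatas()) {
    VLOG(2) << "HLO module: " << metadata->hlo_module().name();
  }
}
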
diff --git a/tensorflow/core/tpu/kernels/tpu_program_group_interface.h b/tensorflow/core/tpu/kernels/tpu_program_group_interface.h
index a4f74fb750d..8d8dd5a8786 100644
--- a/tensorflow/core/tpu/kernels/tpu_program_group_interface.h
+++ b/tensorflow/core/tpu/kernels/tpu_program_group_interface.h
@@ -44,9 +44,9 @@ class TpuProgramGroupInterface {
   // Logs program memory summary.
   virtual bool LogProgramMemorySummary() = 0;
 
-  // Hlo metadatas.
-  virtual std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadatas()
-      const = 0;
+  // Hlo metadatas. The pointers can only be used as long as the cache entry is
+  // referenced.
+  virtual absl::Span<const xla::HloProto* const> hlo_metadatas() const = 0;
 
   // Boolean array to indicate if the modification of variables are
   // allowed.