From 1e3eddd1523a0e9f5afc524ecdbb78e24e4a3b86 Mon Sep 17 00:00:00 2001 From: Henry Tan Date: Fri, 12 Jun 2020 13:57:03 -0700 Subject: [PATCH] TPU library internal change. PiperOrigin-RevId: 316171774 Change-Id: I85b3c2639c734e391692f568bffae0efe116e9af --- tensorflow/core/tpu/kernels/BUILD | 53 +- .../core/tpu/kernels/compiled_subgraph.h | 10 +- .../tpu/kernels/tpu_compilation_cache_entry.h | 24 +- .../tpu_compilation_cache_entry_impl.h | 108 ---- .../kernels/tpu_compilation_cache_external.cc | 566 +++++++++++++++++- .../kernels/tpu_compilation_cache_external.h | 268 ++++++++- .../kernels/tpu_compilation_cache_interface.h | 355 ----------- .../kernels/tpu_compilation_cache_lookup.cc | 17 +- .../kernels/tpu_compilation_cache_lookup.h | 32 +- .../core/tpu/kernels/tpu_compile_op_common.cc | 1 - .../core/tpu/kernels/tpu_compile_op_support.h | 1 + 11 files changed, 848 insertions(+), 587 deletions(-) delete mode 100644 tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h delete mode 100644 tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index ee1cc6f0908..e7be7d2b062 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -19,9 +19,9 @@ cc_library( deps = [ ":tpu_compile_op_support", ":tpu_mesh_state_interface", - ":tpu_program_group_interface", ":tpu_util", ":tpu_util_hdrs", + "@com_google_absl//absl/types:span", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -30,16 +30,16 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + # "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", - "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", "//tensorflow/core/tpu:tpu_configuration", "//tensorflow/core/tpu:tpu_defs", "//tensorflow/stream_executor/tpu:tpu_platform_interface", - "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -157,28 +157,14 @@ cc_library( "tpu_compilation_cache_entry.h", ], deps = [ - ":compiled_subgraph", - ":tpu_compilation_cache_proto_cc", ":tpu_executable_info_proto_cc", ":tpu_program_group", "//tensorflow/compiler/xla/service:hlo_proto_cc", - "//tensorflow/core:framework", "//tensorflow/core/lib/core:refcount", "//tensorflow/core/platform:casts", ], ) -cc_library( - name = "tpu_compilation_cache_entry_impl", - srcs = [], - hdrs = ["tpu_compilation_cache_entry_impl.h"], - deps = [ - ":compiled_subgraph", - ":tpu_compilation_cache_interface", - ":tpu_executable_info_proto_cc", - ], -) - cc_library( name = "tpu_compilation_cache_lookup", srcs = ["tpu_compilation_cache_lookup.cc"], @@ -188,7 +174,6 @@ cc_library( deps = [ ":tpu_compilation_cache_entry", ":tpu_compilation_cache_external", - ":tpu_compilation_cache_interface", ":tpu_compilation_cache_proto_cc", "//tensorflow/core/lib/core:refcount", "//tensorflow/core/platform:status", @@ -262,35 +247,6 @@ cc_library( ], ) -cc_library( - name = "tpu_compilation_cache_interface", - srcs = ["tpu_compilation_cache_interface.cc"], - hdrs = ["tpu_compilation_cache_interface.h"], - deps = [ - ":compiled_subgraph", - ":tpu_compilation_cache_key", - ":tpu_compilation_cache_metrics_hdrs", - ":tpu_compilation_cache_proto_cc", - ":tpu_util", - ":tpu_util_hdrs", - ":trace_util_hdrs", - "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime/rpc:grpc_call", - "//tensorflow/core/platform:casts", # buildcleaner: keep - "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - ], - alwayslink = 1, -) - cc_library( name = "tpu_compilation_cache_external", srcs = ["tpu_compilation_cache_external.cc"], @@ -300,8 +256,6 @@ cc_library( deps = [ ":compiled_subgraph", ":tpu_compilation_cache_entry", - ":tpu_compilation_cache_entry_impl", - ":tpu_compilation_cache_interface", ":tpu_compilation_cache_key", ":tpu_compilation_cache_metrics", # buildcleaner: keep ":tpu_compilation_cache_metrics_hdrs", @@ -401,7 +355,6 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/core/tpu/kernels/compiled_subgraph.h b/tensorflow/core/tpu/kernels/compiled_subgraph.h index a97c652c279..1066e4839dd 100644 --- a/tensorflow/core/tpu/kernels/compiled_subgraph.h +++ b/tensorflow/core/tpu/kernels/compiled_subgraph.h @@ -25,9 +25,6 @@ limitations under the License. namespace tensorflow { namespace tpu { -// Forward declaration to avoid circular dependency. -class TpuCompilationCacheInterface; - // Cache for compiled TPU program. // // Each key identifies a unique subgraph, and the value is the vector of @@ -103,7 +100,10 @@ class TpuCompilationCacheInterface; // unmarked and set to most recently used. // struct CompiledSubgraph : public core::RefCounted { - TpuCompilationCacheInterface* parent = nullptr; // Not owned. + // TODO(henrytan): once `TpuCompilationCache` and + // `TpuCompilationCacheExternal` inherits from `TpuCompilationCacheInterface` + // update void* with `TpuCompilationCacheInterface` + void* parent = nullptr; // Not owned. bool initialized = false; @@ -145,7 +145,7 @@ struct CompiledSubgraph : public core::RefCounted { // owning main entry. CompiledSubgraph* main_entry = nullptr; - // Compiled TPU program group. + // Compiled Tpu program. std::unique_ptr tpu_program_group; // Computes total program size. diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h index b3766b8b4dd..a561fc51778 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h @@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { namespace tpu { -// A version of `CompilationCacheEntry` to access Tpu binary program +// A version of `CompilationCacheEntry` that exposes Tpu binary program // `XLA_TpuProgram`. class TpuCompilationCacheEntry { public: @@ -42,6 +42,28 @@ class TpuCompilationCacheEntry { int core_index_; }; +// Base class for a reference to a cached proto. A unique_ptr to a +// CompilationCacheEntryRef is returned by all the cache Lookup methods below, +// and ensures the underlying proto is not garbage-collected until the client +// discards the ptr. +class CompilationCacheEntryRef { + public: + virtual ~CompilationCacheEntryRef() = default; + + // Returns a CompilationCacheEntry that should not be used beyond the lifetime + // of the CompilationCacheEntryRef. + virtual TpuCompilationCacheEntry get() = 0; +}; + +// Base class that holds references to compiled protos so that the protos are +// not garbage-collected before being used by execute ops. Use +// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete +// ref holder object. +class CompilationRefHolder : public ResourceBase { + public: + ~CompilationRefHolder() override = default; +}; + } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h deleted file mode 100644 index 501f802b01f..00000000000 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ -#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ - -#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" -#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" - -namespace tensorflow { -namespace tpu { - -// Wrapper for a cache entry that holds a reference to the entry until the -// wrapper is deleted. This wrapper is the concrete type of -// CompilationCacheEntryRef returned by Lookup. -template -class CompilationCacheEntryRefImpl - : public CompilationCacheEntryRef { - public: - CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent, - CompiledSubgraph* entry, int index); - - ~CompilationCacheEntryRefImpl() override; - - Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override; - - protected: - TpuCompilationCacheInterface* parent_; // Not owned. - // A reference to entry_ is acquired in the constructor and released via - // parent->DiscardEntryRefs in the destructor. - CompiledSubgraph* entry_; - // The index of the program in entry_ that is returned by the get method. - int index_; -}; - -template -CompilationCacheEntryRefImpl::CompilationCacheEntryRefImpl( - TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index) - : parent_(parent), entry_(entry), index_(index) { - if (entry_ == nullptr) { - return; - } - if (entry_->main_entry == nullptr) { - entry_->Ref(); - } else { - // This is a sharding/unsharding entry nested in a main entry. Only - // refcount the main entry. - entry_->main_entry->Ref(); - } -} - -template -CompilationCacheEntryRefImpl::~CompilationCacheEntryRefImpl() { - if (entry_ == nullptr) { - return; - } - if (entry_->main_entry == nullptr) { - parent_->DiscardEntryRefs({entry_}); - } else { - parent_->DiscardEntryRefs({entry_->main_entry}); - } -} - -template -Status CompilationCacheEntryRefImpl::ToSubEntryRef( - CompilationCacheFetchTarget fetch_target) { - CompiledSubgraph* target = nullptr; - switch (fetch_target) { - case CompilationCacheFetchTarget::MAIN: - target = entry_; - break; - case CompilationCacheFetchTarget::SHARDING: - target = entry_->sharding_entry.get(); - break; - case CompilationCacheFetchTarget::UNSHARDING: - target = entry_->unsharding_entry.get(); - break; - default: - return xla::InvalidArgument("Invalid fetch target: %d", fetch_target); - } - - if (target == nullptr) { - // Cache entry does not have an unsharding subentry. Unref and replace - // with nullptr. - parent_->DiscardEntryRefs({entry_}); - } - // Otherwise, since the refcount is always on the main entry, we don't - // need ref/unref. - entry_ = target; - return Status::OK(); -} - -} // namespace tpu -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc index 8cee90e8e55..614dfbdf577 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc @@ -50,6 +50,14 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry, entry->initialized = true; } +std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) { + if (!key.has_guaranteed_const) { + return key.prefix; + } + return absl::StrCat(key.prefix, "|", key.session_handle, "|", + key.guaranteed_const_fingerprint()); +} + // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor // data to compute the fingerprint. std::string GuaranteedConstFingerprint( @@ -115,32 +123,85 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) { } // namespace -TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl( - TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index) - : CompilationCacheEntryRefImpl(parent, entry, - index) {} - -TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() { - if (entry_ == nullptr) { - // Create an empty entry if the entry is nullptr. This corresponds to - // non-existing sharding/unsharding entries. - return TpuCompilationCacheEntry(); +TpuCompilationCacheExternal::TpuCompilationCacheExternal(int64_t max_cache_size) + : max_cache_size_(max_cache_size) { + if (max_cache_size < 0) { + LOG(FATAL) << "`max_cache_size` value must be greater than equal to 0"; } - return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_); + VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes."; +} + +TpuCompilationCacheExternal::~TpuCompilationCacheExternal() { + VLOG(1) << "TpuCompilationCacheExternal::~TpuCompilationCacheExternal()"; + // A buggy client may be holding onto a reference, or a client might have + // crashed while holding onto a reference. In either case, discard all + // outstanding client references to avoid leaking storage. + for (const auto& entry : entries_by_uid_) { + while (entry.second->external_references > 0) { + TF_CHECK_OK(Release(entry.first)); + } + } + while (!entries_by_last_use_.empty()) { + UnloadAndDestroy(MarkOldestEntryForEviction()); + } + // By the time the cache is deleted all reference holders should have already + // been deleted, since they were holding references to the cache. So all + // entries should be gone at this point. + CHECK_EQ(cache_store_.size(), 0); + CHECK_EQ(entries_by_uid_.size(), 0); + CHECK_EQ(entries_by_proto_key_.size(), 0); + CHECK_EQ(cache_size_, 0); + CHECK_EQ(marked_for_eviction_size_, 0); +} + +std::string TpuCompilationCacheExternal::FindCacheKey( + const TpuCompilationCacheKey& subgraph_key) const { + if (!subgraph_key.has_guaranteed_const) { + return subgraph_key.prefix; + } + auto iter = session_key_map_.find( + strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle)); + if (iter != session_key_map_.end()) { + return iter->second; + } + iter = fingerprint_key_map_.find(strings::StrCat( + subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint())); + if (iter != session_key_map_.end()) { + return iter->second; + } + VLOG(1) << "No matching cache key found for key " + << ConstructCompilationCacheKey(subgraph_key); + return ""; +} + +void TpuCompilationCacheExternal::InsertEntry( + const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key, + CompiledSubgraph* entry) { + entry->parent = this; + entry->subgraph_key = cache_key; + entry->uid = get_uid(); + TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size()); + entry->cache_entry_debug_string = subgraph_key.prefix; + VLOG(1) << "Cache Initializing Entry Session Debug " + << entry->cache_entry_debug_string; + + if (!subgraph_key.has_guaranteed_const) { + return; + } + session_key_map_.insert(std::make_pair( + strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle), + cache_key)); + fingerprint_key_map_.insert(std::make_pair( + strings::StrCat(subgraph_key.prefix, + subgraph_key.guaranteed_const_fingerprint()), + cache_key)); } CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry( const string& key, - const std::function& initialize_program, + const std::function& initialize_program, const TpuCompilationCacheKey& subgraph_key) { CompiledSubgraph* main_entry = new CompiledSubgraph(); - main_entry->parent = this; - main_entry->subgraph_key = key; - main_entry->uid = get_uid(); - // TODO(henrytan): implement TpuCompilationCacheKey.debug_string. - main_entry->cache_entry_debug_string = subgraph_key.prefix; - VLOG(1) << "Cache Initializing Entry Session Debug " - << main_entry->cache_entry_debug_string; // Add the entry to the cache, with size zero since there are no compiled // programs in it. Once the subgraph has been compiled, @@ -151,7 +212,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry( // who created the entry. A second reference, owned by the cache, will be // added below since we leave the entry in the 'marked for eviction' state // here. - InsertEntry(key, main_entry); + InsertEntry(key, subgraph_key, main_entry); // Initialize the programs outside the lock so that other cache operations // can proceed during the (potentially lengthy) initialization. @@ -259,5 +320,470 @@ TpuCompilationCacheExternal::CreateCompilationCacheKey( } return key; } + +TpuCompilationRefHolder* TpuCompilationCacheExternal::MakePerStepRefHolder() { + return new RefHolder(this); +} + +Status TpuCompilationCacheExternal::MarkEntryForEviction(int64 subgraph_uid) { + profiler::TraceMe key_release_traceme( + "TPU compilation cache possibly evict uid", + /*level=*/2); + CompiledSubgraph* deleted_entry = nullptr; + { + absl::MutexLock lock(&mu_); + auto iter = entries_by_uid_.find(subgraph_uid); + if (iter == entries_by_uid_.end()) { + // If already evicted, return ok. + return Status::OK(); + } + + // Mark entry for eviction. + CompiledSubgraph* subgraph_to_evict = iter->second; + // If there are external references, should not use this API. + if (subgraph_to_evict->external_references != 0) { + return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key, + " external_references greater than zero. Should " + "use TpuCompilationCache::Release."); + } + + VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction"; + entries_by_last_use_.erase(subgraph_to_evict->last_use); + cache_size_ -= subgraph_to_evict->total_size; + marked_for_eviction_size_ += subgraph_to_evict->total_size; + + // Evict if refcount exactly one, otherwise only discard cache's reference + // to the entry while the actual eviction will happen when refholder's + // references go away. + deleted_entry = DiscardEntryRef(subgraph_to_evict); + + VLOG(1) << "After possibly evicting entry " << subgraph_uid + << " refs cache is " << cache_store_.size() << " entries (" + << cache_size_ + marked_for_eviction_size_ + << " bytes), marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + } + + // Unload from device cache if entry is evicted from host cache. + UnloadAndDestroy(deleted_entry); + return Status::OK(); +} + +Status TpuCompilationCacheExternal::Release(int64 subgraph_uid) { + profiler::TraceMe key_release_traceme("TPU compilation cache release uid", + /*level=*/2); + + CompiledSubgraph* deleted_entry = nullptr; + { + absl::MutexLock lock(&mu_); + auto iter = entries_by_uid_.find(subgraph_uid); + + if (iter == entries_by_uid_.end()) { + return errors::NotFound("No cache entry found for uid ", subgraph_uid); + } + + CHECK_GT(iter->second->external_references, 0); + --iter->second->external_references; + + deleted_entry = DiscardEntryRef(iter->second); + + VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is " + << cache_store_.size() << " entries (" + << cache_size_ + marked_for_eviction_size_ + << " bytes), marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + } + UnloadAndDestroy(deleted_entry); + return Status::OK(); +} + +void TpuCompilationCacheExternal::UnloadAndDestroy(CompiledSubgraph* entry) { + if (!entry) return; + + CHECK(entry->RefCountIsOne()); + entry->tpu_program_group->UnloadAndDestroyPrograms(); + entry->Unref(); +} + +size_t TpuCompilationCacheExternal::RemoveEntry(const string& key) { + auto erased = cache_store_.erase(key); + TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size()); + auto parsed_key_or_status = ParseCompilationCacheKey(key); + CHECK(parsed_key_or_status.status().ok()); + const TpuCompilationCacheKey parsed_key = + parsed_key_or_status.ConsumeValueOrDie(); + if (!parsed_key.has_guaranteed_const) { + return erased; + } + session_key_map_.erase( + strings::StrCat(parsed_key.prefix, parsed_key.session_handle)); + fingerprint_key_map_.erase(strings::StrCat( + parsed_key.prefix, parsed_key.guaranteed_const_fingerprint())); + return erased; +} + +ABSL_MUST_USE_RESULT CompiledSubgraph* +TpuCompilationCacheExternal::DiscardEntryRef(CompiledSubgraph* entry) { + if (entry->RefCountIsOne()) { + // The last reference to this entry is going away, so really delete it from + // the cache in such a way that it can't be restored by being looked up + // again. + + // Sanity-check that it has been marked for eviction. + CHECK(entries_by_last_use_.find(entry->last_use) == + entries_by_last_use_.end()); + // Update the counter tracking how much space is taken up by entries that + // are marked for eviction. + marked_for_eviction_size_ -= entry->total_size; + + // Remove the entry from the cache. + auto erased = RemoveEntry(entry->subgraph_key); + + if (erased == 0) { + LOG(FATAL) << "Tried to discard nonexistent cache entry"; + } + erased = entries_by_uid_.erase(entry->uid); + CHECK_EQ(erased, 1); + for (const string& key : entry->proto_key) { + erased = entries_by_proto_key_.erase(key); + CHECK_EQ(erased, 1); + } + // The actual deletion will happen outside the lock in UnloadAndDestroy(). + return entry; + } + entry->Unref(); + return nullptr; +} + +void TpuCompilationCacheExternal::DiscardEntryRefs( + gtl::ArraySlice entries) { + std::vector removed_entries; + { + absl::MutexLock lock(&mu_); + + for (auto entry : entries) { + removed_entries.push_back(DiscardEntryRef(entry)); + } + + VLOG(1) << "After discarding entry refs cache is " << cache_store_.size() + << " entries (" << cache_size_ + marked_for_eviction_size_ + << " bytes), marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + } + for (auto removed_entry : removed_entries) { + UnloadAndDestroy(removed_entry); + } +} + +ABSL_MUST_USE_RESULT CompiledSubgraph* +TpuCompilationCacheExternal::MarkOldestEntryForEviction() { + CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second; + VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction"; + entries_by_last_use_.erase(entry_to_mark->last_use); + cache_size_ -= entry_to_mark->total_size; + marked_for_eviction_size_ += entry_to_mark->total_size; + // Discard the cache's reference to entry. If steps are holding onto + // references to entry it won't be deleted until the last step holding it + // completes. It stays in the cache in the meantime and can be resurrected + // by a call to CompileIfKeyAbsent if that occurs before the last reference + // expires. + return DiscardEntryRef(entry_to_mark); +} + +void TpuCompilationCacheExternal::LookupEntryMarkedForEviction( + CompiledSubgraph* entry, std::vector* removed_entries) { + // The entry was previously marked for eviction (or is newly created) so + // unmark it. Add a reference (owned by the cache), update the cache size, and + // mark something old for eviction if necessary. + entry->Ref(); + marked_for_eviction_size_ -= entry->total_size; + cache_size_ += entry->total_size; + + // Mark the least-recently-used non-marked entry for eviction. Never mark the + // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1 + // which means there's only one entry not already marked for eviction), so + // that an entry persists in the cache even if it is larger than the allocated + // cache size. + while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) { + if (auto entry_to_evict = MarkOldestEntryForEviction()) { + removed_entries->push_back(entry_to_evict); + } + } +} + +Status TpuCompilationCacheExternal::ToSubEntryRef( + CompilationCacheEntryRef* entry, + CompilationCacheFetchTarget fetch_target) const { + return static_cast(entry)->ToSubEntryRef(fetch_target); +} + +TpuCompilationCacheExternal::TpuEntryRefImpl::TpuEntryRefImpl( + TpuCompilationCacheExternal* parent, CompiledSubgraph* entry, int index) + : parent_(parent), entry_(entry), index_(index) { + if (entry_ == nullptr) { + return; + } + if (entry_->main_entry == nullptr) { + entry_->Ref(); + } else { + // This is a sharding/unsharding entry nested in a main entry. Only refcount + // the main entry. + entry_->main_entry->Ref(); + } +} + +TpuCompilationCacheExternal::TpuEntryRefImpl::~TpuEntryRefImpl() { + if (entry_ == nullptr) { + return; + } + if (entry_->main_entry == nullptr) { + parent_->DiscardEntryRefs({entry_}); + } else { + parent_->DiscardEntryRefs({entry_->main_entry}); + } +} + +TpuCompilationCacheEntry TpuCompilationCacheExternal::TpuEntryRefImpl::get() { + if (entry_ == nullptr) { + // Create an empty entry if the entry is nullptr. This corresponds to + // non-existing sharding/unsharding entries. + return TpuCompilationCacheEntry(); + } + return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_); +} + +Status TpuCompilationCacheExternal::TpuEntryRefImpl::ToSubEntryRef( + CompilationCacheFetchTarget fetch_target) { + CompiledSubgraph* target = nullptr; + switch (fetch_target) { + case CompilationCacheFetchTarget::MAIN: + target = entry_; + break; + case CompilationCacheFetchTarget::SHARDING: + target = entry_->sharding_entry.get(); + break; + case CompilationCacheFetchTarget::UNSHARDING: + target = entry_->unsharding_entry.get(); + break; + default: + return xla::InvalidArgument("Invalid fetch target: %d", fetch_target); + } + + if (target == nullptr) { + // Cache entry does not have an unsharding subentry. Unref and replace + // with nullptr. + parent_->DiscardEntryRefs({entry_}); + } + // Otherwise, since the refcount is always on the main entry, we don't need + // ref/unref. + entry_ = target; + return Status::OK(); +} + +Status TpuCompilationCacheExternal::Lookup( + int64 uid, int proto_index, + std::unique_ptr* entry) { + entry->reset(); + + profiler::TraceMe proto_lookup_traceme( + "TPU compilation cache proto lookup by uid", + /*level=*/2); + + absl::MutexLock lock(&mu_); + const auto iter = entries_by_uid_.find(uid); + if (iter == entries_by_uid_.end()) { + return errors::NotFound("No subgraph found for uid ", uid); + } + CompiledSubgraph* cache_entry = iter->second; + if (proto_index < 0 || + proto_index >= cache_entry->tpu_program_group->program_size()) { + return errors::NotFound("No proto found for core index ", proto_index, + " in subgraph with uid ", uid); + } + *entry = std::unique_ptr( + new TpuEntryRefImpl(this, cache_entry, proto_index)); + return Status::OK(); +} + +Status TpuCompilationCacheExternal::Lookup( + const string& proto_key, std::unique_ptr* entry) { + entry->reset(); + + profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup", + /*level=*/2); + + absl::MutexLock lock(&mu_); + const auto iter = entries_by_proto_key_.find(proto_key); + if (iter == entries_by_proto_key_.end()) { + return errors::NotFound("No proto found for key ", proto_key); + } + CompiledSubgraph* cache_entry = iter->second.first; + int proto_index = iter->second.second; + *entry = std::unique_ptr( + new TpuEntryRefImpl(this, cache_entry, proto_index)); + return Status::OK(); +} + +Status TpuCompilationCacheExternal::CompileIfKeyAbsentHelper( + const TpuCompilationCacheKey& subgraph_key, + const SessionMetadata* session_metadata, + TpuCompilationRefHolder* per_step_ref_holder, int64* uid, + std::vector* proto_key, std::vector* may_modify_variables, + std::vector* removed_entries, + std::vector>* hlo_metadata, + const std::function& compile_function) { + profiler::TraceMe subgraph_lookup_traceme( + "TPU compilation cache subgraph lookup", + /*level=*/2); + + // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock + // for the lifetime of the object, see InitializeEntry() call below. + absl::MutexLock lock(&mu_); + + std::string cache_key = FindCacheKey(subgraph_key); + auto iter = cache_store_.find(cache_key); + bool is_new_key = iter == cache_store_.end(); + + const std::string session_name = SessionNameFromMetadata(session_metadata); + + CompiledSubgraph* entry = nullptr; + if (is_new_key) { + cache_key = ConstructCompilationCacheKey(subgraph_key); + TpuCompilationCacheMetrics::IncrementCacheLookupCount( + /*is_cache_hit=*/false, session_name); + const string msg = + strings::StrCat("TPU host compilation cache miss: cache_key(", + cache_key, "), session_name(", session_name, ")"); + + TRACESTRING(msg); + LOG(INFO) << msg; + + // Check if caller has disabled compilation. Set using + // internal::ScopedTpuCompileDisabler. + if (!IsTpuCompilationEnabled()) { + const string error_msg = strings::StrCat( + "[TpuCompilationDisabled]: Compilation cache miss, but compilation " + "disabled, session_name(", + session_name, ") Debug String: ", subgraph_key.debug_string); + if (VLOG_IS_ON(2)) { + VLOG(2) << "Cache Missed. Current cache entries: "; + for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) { + // TODO(henrytan): add DebugKey as cache_entry_debug_string to + // TpuCompilationCacheKey. + VLOG(2) << "Cache Debug Info: "; + VLOG(2) << it->second->cache_entry_debug_string; + } + } + + LOG_EVERY_N_SEC(WARNING, 30) << error_msg; + return errors::NotFound(error_msg); + } + + // The single ref on the newly-created entry is owned by the caller. + VLOG(1) << "Before adding new entry for key " << cache_key + << " with session_name( " << session_name << ");" + << "; cache is " << cache_store_.size() << " entries (" + << cache_size_ + marked_for_eviction_size_ << " bytes), " + << " marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + // Note that InitializeEntry() will Release/Reacquire mu_. + entry = InitializeEntry(cache_key, compile_function, subgraph_key); + TRACELITERAL("TPU host compilation cache: compilation done."); + + LOG(INFO) << strings::StrCat( + "TPU host compilation cache: compilation done for cache_key(", + cache_key, "), session_name(", session_name, ")"); + // If session_name is present, log some additional stats related to HBM + // here, so that they can be associated directly to the session. + if (!session_name.empty()) { + entry->tpu_program_group->LogProgramMemorySummary(); + } + } else { + TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name); + const string msg = + strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key, + "), session_name(", session_name, ")"); + TRACESTRING(msg); + VLOG(1) << msg; + VLOG(1) << "Before refreshing entry for key " << cache_key + << " with session_name( " << session_name << "); cache is " + << cache_store_.size() << " entries (" + << cache_size_ + marked_for_eviction_size_ << " bytes), " + << " marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + entry = iter->second; + // Make a new reference that is owned by the caller. + entry->Ref(); + // Block if necessary until the subgraph has been initialized. + mu_.Await(absl::Condition( + +[](CompiledSubgraph* e) { return e->initialized; }, entry)); + } + + // Let the caller know the uid of the entry. + *uid = entry->uid; + // Let the caller know the keys for each of the cached protos. + *proto_key = entry->proto_key; + *may_modify_variables = entry->tpu_program_group->may_modify_variables(); + *hlo_metadata = entry->tpu_program_group->hlo_metadatas(); + + // If the caller didn't supply a per_step_ref_holder then the caller is going + // to manually release the reference later via a call to Release(). + if (per_step_ref_holder == nullptr) { + ++entry->external_references; + } else { + // The caller wants its reference to be handed off to a per-step holder that + // will discard the reference when the step completes. + RefHolder* cast_ref_holder = static_cast(per_step_ref_holder); + TF_RET_CHECK(cast_ref_holder != nullptr); + cast_ref_holder->AddRef(entry); + } + + // Remove the old LRU-table entry if it wasn't already marked for eviction. + auto erased = entries_by_last_use_.erase(entry->last_use); + // Update the LRU table indicating this entry is the most recently used. + entry->last_use = use_counter_++; + entries_by_last_use_[entry->last_use] = entry; + if (erased == 0) { + // The entry had been marked for eviction, or is newly created. + LookupEntryMarkedForEviction(entry, removed_entries); + } + + // Log a little more verbosely when a key is added. + if (VLOG_IS_ON(1) || is_new_key) { + LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing") + << " entry for key " << cache_key << " with session_name " + << session_name << " cache is " << cache_store_.size() + << " entries (" << cache_size_ + marked_for_eviction_size_ + << " bytes), " + << " marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + } + return entry->initialization_status; +} + +tensorflow::Status TpuCompilationCacheExternal::CompileIfKeyAbsent( + const TpuCompilationCacheKey& cache_key, + const tensorflow::SessionMetadata* session_metadata, + TpuCompilationRefHolder* per_step_ref_holder, int64* uid, + std::vector* proto_key, std::vector* may_modify_variables, + std::vector>* hlo_metadata, + const std::function& + compile_function) { + std::vector removed_entries; + auto status = CompileIfKeyAbsentHelper( + cache_key, session_metadata, per_step_ref_holder, uid, proto_key, + may_modify_variables, &removed_entries, hlo_metadata, compile_function); + for (auto entry : removed_entries) { + UnloadAndDestroy(entry); + } + return status; +} + } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h index 2c75cb4d053..eff2afde108 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h @@ -26,14 +26,11 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/compiled_subgraph.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" @@ -43,25 +40,37 @@ limitations under the License. namespace tensorflow { namespace tpu { -constexpr char kCompilationCacheResourceName[] = "tpu_compilation_cache"; -constexpr char kCompilationCacheUnloaderResourceName[] = +const char kCompilationCacheResourceName[] = "tpu_compilation_cache"; +const char kCompilationCacheUnloaderResourceName[] = "tpu_compilation_cache_unloader"; -class TpuCompilationCacheExternal : public TpuCompilationCacheInterface { +// Base class that holds references to compiled protos so that the protos are +// not garbage-collected before being used by execute ops. Use +// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete +// ref holder object. +class TpuCompilationRefHolder : public ResourceBase { + public: + ~TpuCompilationRefHolder() override = default; +}; + +class TpuCompilationCacheExternal : public ResourceBase { public: using Status = ::stream_executor::port::Status; - class EntryRefImpl - : public CompilationCacheEntryRefImpl { - public: - EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, - int index); + explicit TpuCompilationCacheExternal(int64_t max_cache_size); + ~TpuCompilationCacheExternal() override; + TpuCompilationCacheExternal(const TpuCompilationCacheExternal&) = delete; + TpuCompilationCacheExternal& operator=(const TpuCompilationCacheExternal&) = + delete; - TpuCompilationCacheEntry get() override; - }; - - explicit TpuCompilationCacheExternal(int64 max_cache_size) - : TpuCompilationCacheInterface(max_cache_size) {} + Status CompileIfKeyAbsent( + const TpuCompilationCacheKey& cache_key, + const SessionMetadata* session_metadata, + TpuCompilationRefHolder* per_step_ref_holder, int64* uid, + std::vector* proto_key, std::vector* may_modify_variables, + std::vector>* hlo_metadata, + const std::function& + compile_function); static TpuCompilationCacheKey CreateCompilationCacheKey( absl::string_view function_name, uint64 function_library_fingerprint, @@ -73,7 +82,177 @@ class TpuCompilationCacheExternal : public TpuCompilationCacheInterface { string DebugString() const override { return "TpuCompilationCacheExternal"; } + // Makes a reference holder for this cache, that can be stored in the per-step + // resource manager and will ensure that compiled entries persist until the + // end of a step. + TpuCompilationRefHolder* MakePerStepRefHolder(); + + // Differences between MarkEntryForEviction and Release: + // There are two modes of managing cache entries: + // 1) LRU eviction + pinning; 2) manual. + // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent. + // Otherwise it is manual mode (mainly used by XRT). + // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache + // entries when callers know that they do not need them anymore. + // Release should only be used in mode 2) to explicitly remove an entry. + + // Mark the entry indexed by `subgraph_uid` for eviction. This should only be + // called if per_step_ref_holder was NOT nullptr in the corresponding call to + // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64 + // subgraph_uid). + Status MarkEntryForEviction(int64 subgraph_uid); + + // Manually discards a reference to the compiled subgraph. This should only be + // called if per_step_ref_holder was nullptr in the corresponding call to + // CompileIfKeyAbsent(subgraph_key, ...). + Status Release(int64 subgraph_uid); + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by key. On success a pointer to an EntryRef + // holding the program is returned in entry. + Status Lookup(const string& proto_key, + std::unique_ptr* entry); + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by uid. On success a pointer to an EntryRef + // holding the program is returned in entry. + Status Lookup(int64 uid, int proto_index, + std::unique_ptr* entry); + + // Mutates the main entry ref to point to the entry's subentry + // (for sharding/unsharding) or main entry (unchanged) representing the + // fetch target. The entry ref needs to point to the main entry before this + // call. + // + // If the requested subentry does not exist, the ref will point to a nullptr + // entry. + Status ToSubEntryRef(CompilationCacheEntryRef* entry, + CompilationCacheFetchTarget fetch_target) const; + private: + // Wrapper for a cache entry that holds a reference to the entry until the + // wrapper is deleted. This wrapper is the concrete type of + // CompilationCacheEntryRef returned by Lookup. + class TpuEntryRefImpl : public CompilationCacheEntryRef { + public: + TpuEntryRefImpl(TpuCompilationCacheExternal* parent, + CompiledSubgraph* entry, int index); + ~TpuEntryRefImpl() override; + + TpuCompilationCacheEntry get() override; + + // Mutates this ref to point to the entry's subentry (for + // sharding/unsharding) or main entry (unchanged) as specified by + // fetch_target. The refcount is kept unchanged, since we only track the + // refcount of the main entry. The entry ref needs to point to the main + // entry before this call. + // + // If the requested subentry does not exist, the ref will point to a nullptr + // entry, and the original entry will be unref'ed. + Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target); + + private: + TpuCompilationCacheExternal* parent_; // Not owned. + // A reference to entry_ is acquired in the constructor and released via + // parent->DiscardEntryRefs in the destructor. + CompiledSubgraph* entry_; + // The program in entry_ that is returned by the get method. + int index_; + }; + + // Private implementation of the generic CompilationRefHolder that knows about + // CompiledSubgraph entries. + class RefHolder : public TpuCompilationRefHolder { + public: + explicit RefHolder(TpuCompilationCacheExternal* parent) : parent_(parent) { + parent_->Ref(); + } + ~RefHolder() override { + // Release our reference to the parent. + parent_->Unref(); + } + + // Adds entry to the list of entries that will be released when the + // RefHolder is destroyed. Each entry is released via a call to + // parent_->DiscardEntryRefs. + void AddRef(CompiledSubgraph* entry) { entries_.push_back(entry); } + + string DebugString() const override { + return "TpuCompilationCacheExternal::RefHolder"; + } + + private: + TpuCompilationCacheExternal* parent_; // Not owned. + std::vector entries_; + }; + + // The bulk of implementation of CompileIfKeyAbsent() with the exception + // of unloading programs that corresponds to possibly removed cache + // entries. The split helps to manage locking since we prefer to perform + // unloading without holding extra locks. + Status CompileIfKeyAbsentHelper( + const TpuCompilationCacheKey& subgraph_key, + const SessionMetadata* session_metadata, + TpuCompilationRefHolder* per_step_ref_holder, int64* uid, + std::vector* proto_key, std::vector* may_modify_variables, + std::vector* removed_entries, + std::vector>* hlo_metadata, + const std::function& compile_function); + + // This is called by the cache when entry is marked for eviction; by + // a RefHolder (via DiscardEntryRefs) when a step completes; and by + // an EntryRefImpl when it is destroyed. Releases one reference to entry + // if more than 1 remains. If only one reference is left, the entry is removed + // from cache_ and is returned to the caller; which must eventually call + // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef + // to avoid holding the lock during program unloading. + ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef( + CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Convenience method called by ~RefHolder without mu_ held. Calls + // DiscardEntryRef on every element of entries. + void DiscardEntryRefs(gtl::ArraySlice entries); + + // Marks the oldest unmarked entry for eviction. Requires that there is at + // least one such entry. In case the evicted entry had only 1 reference it + // is removed from the cache and returned to the caller which must eventually + // call UnloadAndDestroy. + CompiledSubgraph* MarkOldestEntryForEviction() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Updates datastructures to indicate that entry, which had been marked for + // eviction, has been looked up. This is called by CompileIfKeyAbsent when an + // entry is newly created, or an entry that has been marked for eviction but + // not yet evicted is looked up. + // + // First the entry is unmarked for eviction, i.e. the cache gains a reference + // to entry, entry's last_use field is set to be the most recent value of + // use_counter_ and entries_by_last_use_ is updated accordingly. + // + // Next, the size of the cache is examined to see if any other entries need to + // be marked for eviction now that entry has been unmarked. While the total + // size of unmarked cached entries is greater than max_cache_size_, entries + // are marked for eviction in LRU order. The most recently used entry is never + // marked for eviction, so an entry larger than the max cache size will remain + // in the cache until it is replaced by something else. In case some entries + // actually were removed from the cache, they are a returned to the caller via + // removed_entries. The caller must eventually delete them by calling + // UnloadAndDestroy. + void LookupEntryMarkedForEviction( + CompiledSubgraph* entry, std::vector* removed_entries) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Removes the entry with given key from cache. + size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Inserts the given key and entry to cache. + void InsertEntry(const std::string& key, + const TpuCompilationCacheKey& subgraph_key, + CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the cache key matching given subgraph_key. + std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Creates a new entry by running initialize_programs and places it in the // cache to be looked up by key. The new entry is in the 'marked for eviction' // state (not present in entries_by_last_use_) and the caller is expected to @@ -82,10 +261,61 @@ class TpuCompilationCacheExternal : public TpuCompilationCacheInterface { // **InitializeEntry releases mu_ during the call to initialize_programs.** CompiledSubgraph* InitializeEntry( const string& key, - const std::function& - initialize_program, + const std::function& initialize_program, const TpuCompilationCacheKey& subgraph_key) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(TpuCompilationCacheInterface::mu_) override; + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Unloads the program associated with the entry from all local devices + // and deletes the entry itself. It is assumed no one else has a reference + // to it and all related keys had already been removed from the cache. + // The call can perform device IO so no locks should be held while calling it. + void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_); + + // The maximum size of entries that are stored in the cache before entries are + // marked for eviction. + const int64 max_cache_size_; + + mutable absl::Mutex mu_; + // The total size of entries that are stored and not marked for eviction. + int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0; + + // The total size of entries that are marked for eviction. + int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0; + + // The value to assign to the last_use field of the next entry that is looked + // up. + int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0; + + // session_key_map_ and fingerprint_key_map_ are used for looking up the + // cache_ key matching a given subgraph key. When doing a lookup, check + // session_key_map_ first to avoid unnecessay fingerprint computation. + // Map from key prefix + session_handle to a cache_ key. + std::unordered_map session_key_map_ ABSL_GUARDED_BY(mu_); + + // Map from key prefix + fingerprint to a cache_ key. + std::unordered_map fingerprint_key_map_ ABSL_GUARDED_BY(mu_); + + // All the subgraph entries that can be looked up in the cache. An entry is + // marked for eviction iff it is present in cache_ and not in + // entries_by_last_use_. + std::unordered_map cache_store_ + ABSL_GUARDED_BY(mu_); + + // All the subgraph entries that can be looked up in the cache, indexed by + // uid. + absl::node_hash_map entries_by_uid_ + ABSL_GUARDED_BY(mu_); + + // All the protos that can be looked up in the cache, indexed by proto + // key. The value of the map is a subgraph and the index of the proto compiled + // for that subgraph. + std::unordered_map> + entries_by_proto_key_ ABSL_GUARDED_BY(mu_); + + // Map from last_use to entry, used to mark entries for eviction in LRU + // order. If an entry's last_use counter is not present as a key in + // entries_by_last_use_ then the entry has been marked for eviction. + std::map entries_by_last_use_ ABSL_GUARDED_BY(mu_); }; } // namespace tpu diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h deleted file mode 100644 index 8d98a265f35..00000000000 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h +++ /dev/null @@ -1,355 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_ -#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_ - -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/node_hash_map.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/profiler/lib/traceme.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h" -#include "tensorflow/core/tpu/kernels/trace_util.h" - -namespace tensorflow { -namespace tpu { - -// Base class that holds references to compiled protos so that the protos are -// not garbage-collected before being used by execute ops. Use -// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete -// ref holder object. -class CompilationRefHolder : public ResourceBase { - public: - ~CompilationRefHolder() override = default; -}; - -// Base class for a reference to a cached tpu program. A unique_ptr to a -// CompilationCacheEntryRef is returned by all the cache Lookup methods below, -// and ensures the underlying proto is not garbage-collected until the client -// discards the ptr. -template -class CompilationCacheEntryRef { - public: - virtual ~CompilationCacheEntryRef() = default; - - // Returns a CompilationCacheEntry that should not be used beyond the lifetime - // of the tpu::CompilationCacheEntryRef. - virtual CacheEntryType get() = 0; - - // Mutates this ref to point to the entry's subentry (for - // sharding/unsharding) or main entry (unchanged) as specified by - // fetch_target. The refcount is kept unchanged, since we only track the - // refcount of the main entry. The entry ref needs to point to the main - // entry before this call. - // - // If the requested subentry does not exist, the ref will point to a nullptr - // entry, and the original entry will be unref'ed. - virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0; -}; - -class TpuCompilationCacheInterface : public ResourceBase { - public: - explicit TpuCompilationCacheInterface(int64 max_cache_size); - ~TpuCompilationCacheInterface() override; - - // Ensures there is an entry for key present in the cache. By the time - // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache - // for key, and that entry will remain valid at least until - // per_step_ref_holder is deleted. The first call to CompileIfKeyAbsent with a - // key that is not in the cache will evaluate compile_function to compute the - // value to use in the entry. Subsequent calls with the same key will block - // until compile_function completes. Other cache reads and inserts may proceed - // on other threads while compile_function is executing. If - // per_step_ref_holder is nullptr then the caller is responsible for calling - // Release(subgraph_key) to manually discard its reference to the compiled - // program, once the caller will not look up the compiled program again. - // - // compile_function should compile the subgraph represented by key and fill in - // one TPUExecutableProto per model-parallel core into its passed argument. It - // should return OK if and only if compilation succeeds. The executable proto - // vector will be discarded on non-OK status. - Status CompileIfKeyAbsent( - const TpuCompilationCacheKey& subgraph_key, - const SessionMetadata* session_metadata, - CompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector* proto_key, std::vector* may_modify_variables, - std::vector>* hlo_metadatas, - const std::function& compile_function); - - // Differences between MarkEntryForEviction and Release: - // There are two modes of managing cache entries: - // 1) LRU eviction + pinning; 2) manual. - // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent. - // Otherwise it is manual mode (mainly used by XRT). - // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache - // entries when callers know that they do not need them anymore. - // Release should only be used in mode 2) to explicitly remove an entry. - - // Mark the entry indexed by `subgraph_uid` for eviction. This should only be - // called if per_step_ref_holder was NOT nullptr in the corresponding call to - // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64 - // subgraph_uid). - Status MarkEntryForEviction(int64 subgraph_uid); - - // Manually discards a reference to the compiled subgraph. This should only be - // called if per_step_ref_holder was nullptr in the corresponding call to - // CompileIfKeyAbsent(subgraph_key, ...). - Status Release(int64 subgraph_uid); - - // Looks up an executable corresponding to the model-parallel core index of - // the subgraph represented by key. On success a pointer to an EntryRef - // holding the program is returned in entry. - template - Status Lookup(const string& proto_key, std::unique_ptr* entry); - - // Looks up an executable corresponding to the model-parallel core index of - // the subgraph represented by uid. On success a pointer to an EntryRef - // holding the program is returned in entry. - template - Status Lookup(int64 uid, int proto_index, - std::unique_ptr* entry); - - // Looks up the subgraph represented by uid, and returns the vector of keys, - // one per core, corresponding to that subgraph. - Status GetKeysFromUid(int64 uid, std::vector* keys); - - // Makes a reference holder for this cache, that can be stored in the per-step - // resource manager and will ensure that compiled entries persist until the - // end of a step. - CompilationRefHolder* MakePerStepRefHolder(); - - // Convenience method called by ~RefHolder without mu_ held. Calls - // DiscardEntryRef on every element of entries. - void DiscardEntryRefs(gtl::ArraySlice entries); - - string DebugString() const override { return "TpuCompilationCacheBase"; } - - protected: - std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) { - if (!key.has_guaranteed_const) { - return key.prefix; - } - return absl::StrCat(key.prefix, "|", key.session_handle, "|", - key.guaranteed_const_fingerprint()); - } - - // Private implementation of the generic CompilationRefHolder that knows about - // CompiledSubgraph entries. - class RefHolder : public CompilationRefHolder { - public: - explicit RefHolder(TpuCompilationCacheInterface* parent); - ~RefHolder() override; - - // Adds entry to the list of entries that will be released when the - // RefHolder is destroyed. Each entry is released via a call to - // parent_->DiscardEntryRefs. - void AddRef(CompiledSubgraph* entry); - - string DebugString() const override; - - private: - TpuCompilationCacheInterface* parent_; // Not owned. - std::vector entries_; - }; - - // The bulk of implementation of CompileIfKeyAbsent() with the exception - // of unloading programs that corresponds to possibly removed cache - // entries. The split helps to manage locking since we prefer to perform - // unloading without holding extra locks. - Status CompileIfKeyAbsentHelper( - const TpuCompilationCacheKey& subgraph_key, - const SessionMetadata* session_metadata, - CompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector* proto_key, std::vector* may_modify_variables, - std::vector* removed_entries, - std::vector>* hlo_metadatas, - const std::function& compile_function); - - // This is called by the cache when entry is marked for eviction; by - // a RefHolder (via DiscardEntryRefs) when a step completes; and by - // an EntryRefImpl when it is destroyed. Releases one reference to entry - // if more than 1 remains. If only one reference is left, the entry is removed - // from cache_ and is returned to the caller; which must eventually call - // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef - // to avoid holding the lock during program unloading. - ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef( - CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Marks the oldest unmarked entry for eviction. Requires that there is at - // least one such entry. In case the evicted entry had only 1 reference it - // is removed from the cache and returned to the caller which must eventually - // call UnloadAndDestroy. - ABSL_MUST_USE_RESULT CompiledSubgraph* MarkOldestEntryForEviction() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Updates datastructures to indicate that entry, which had been marked for - // eviction, has been looked up. This is called by CompileIfKeyAbsent when an - // entry is newly created, or an entry that has been marked for eviction but - // not yet evicted is looked up. - // - // First the entry is unmarked for eviction, i.e. the cache gains a reference - // to entry, entry's last_use field is set to be the most recent value of - // use_counter_ and entries_by_last_use_ is updated accordingly. - // - // Next, the size of the cache is examined to see if any other entries need to - // be marked for eviction now that entry has been unmarked. While the total - // size of unmarked cached entries is greater than max_cache_size_, entries - // are marked for eviction in LRU order. The most recently used entry is never - // marked for eviction, so an entry larger than the max cache size will remain - // in the cache until it is replaced by something else. In case some entries - // actually were removed from the cache, they are a returned to the caller via - // removed_entries. The caller must eventually delete them by calling - // UnloadAndDestroy. - void LookupEntryMarkedForEviction( - CompiledSubgraph* entry, std::vector* removed_entries) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Removes the entry with given key from cache. - size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Inserts the given key and entry to cache. - void InsertEntry(const string& key, CompiledSubgraph* entry) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Returns the cache key matching given subgraph_key. - string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Creates a new entry by running initialize_programs and places it in the - // cache to be looked up by key. The new entry is in the 'marked for eviction' - // state (not present in entries_by_last_use_) and the caller is expected to - // call LookupEntryMarkedForEviction after InitializeEntry. - // - // **InitializeEntry releases mu_ during the call to initialize_programs.** - virtual CompiledSubgraph* InitializeEntry( - const string& key, - const std::function& - initialize_programs, - const TpuCompilationCacheKey& subgraph_key) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; - - // Unloads the program associated with the entry from all local devices - // and deletes the entry itself. It is assumed no one else has a reference - // to it and all related keys had already been removed from the cache. - // The call can perform device IO so no locks should be held while calling it. - void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_); - - // The maximum size of entries that are stored in the cache before entries are - // marked for eviction. - const int64 max_cache_size_; - // Mutex to protect access to shared resources under multi-threading - // environment. - absl::Mutex mu_; - // The total size of entries that are stored and not marked for eviction. - int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0; - // The total size of entries that are marked for eviction. - int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0; - // The value to assign to the last_use field of the next entry that is looked - // up. - int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0; - // session_key_map_ and fingerprint_key_map_ are used for looking up the - // cache_ key matching a given subgraph key. When doing a lookup, check - // session_key_map_ first to avoid unnecessay fingerprint computation. - // Map from key prefix + session_handle to a cache_ key. - absl::node_hash_map session_key_map_ ABSL_GUARDED_BY(mu_); - // Map from key prefix + fingerprint to a cache_ key. - absl::node_hash_map fingerprint_key_map_ ABSL_GUARDED_BY(mu_); - // All the subgraph entries that can be looked up in the cache. An entry is - // marked for eviction iff it is present in cache_ and not in - // entries_by_last_use_. - std::unordered_map cache_ ABSL_GUARDED_BY(mu_); - // All the subgraph entries that can be looked up in the cache, indexed by - // uid. - absl::node_hash_map entries_by_uid_ - ABSL_GUARDED_BY(mu_); - // All the protos that can be looked up in the cache, indexed by proto - // key. The value of the map is a subgraph and the index of the proto compiled - // for that subgraph. - std::unordered_map> - entries_by_proto_key_ ABSL_GUARDED_BY(mu_); - // Map from last_use to entry, used to mark entries for eviction in LRU - // order. If an entry's last_use counter is not present as a key in - // entries_by_last_use_ then the entry has been marked for eviction. - std::map entries_by_last_use_ ABSL_GUARDED_BY(mu_); - - TpuCompilationCacheMetrics tpu_compilation_cache_metrics_; - - private: - TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete; - TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) = - delete; -}; - -template -Status TpuCompilationCacheInterface::Lookup( - int64 uid, int proto_index, std::unique_ptr* entry) { - entry->reset(); - - profiler::TraceMe proto_lookup_traceme( - "TPU compilation cache proto lookup by uid", - /*level=*/2); - - absl::MutexLock lock(&mu_); - const auto iter = entries_by_uid_.find(uid); - if (iter == entries_by_uid_.end()) { - return errors::NotFound("No subgraph found for uid ", uid); - } - CompiledSubgraph* cache_entry = iter->second; - if (proto_index < 0 || - proto_index >= cache_entry->tpu_program_group->program_count()) { - return errors::NotFound("No proto found for core index ", proto_index, - " in subgraph with uid ", uid); - } - *entry = absl::make_unique(this, cache_entry, proto_index); - return Status::OK(); -} - -template -Status TpuCompilationCacheInterface::Lookup( - const string& proto_key, std::unique_ptr* entry) { - entry->reset(); - - profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup", - /*level=*/2); - - absl::MutexLock lock(&mu_); - const auto iter = entries_by_proto_key_.find(proto_key); - if (iter == entries_by_proto_key_.end()) { - return errors::NotFound("No proto found for key ", proto_key); - } - CompiledSubgraph* cache_entry = iter->second.first; - int proto_index = iter->second.second; - *entry = absl::make_unique(this, cache_entry, proto_index); - return Status::OK(); -} - -} // namespace tpu -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc index 9285dff62ce..8b2e832a69e 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc @@ -42,7 +42,7 @@ std::string GetName(CompilationCacheFetchTarget target) { } // namespace TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup( - TpuCompilationCacheInterface* cache) + TpuCompilationCacheExternal* cache) : cache_(cache) {} TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() { @@ -50,19 +50,17 @@ TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() { } Status TpuCompilationCacheLocalLookup::Lookup( - const string& proto_key, - std::unique_ptr* entry, + const string& proto_key, std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) { profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup", /*level=*/2); - Status s = cache_->Lookup( - proto_key, entry); + Status s = cache_->Lookup(proto_key, entry); VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status " << s; if (!s.ok()) { return s; } - s = (*entry)->ToSubEntryRef(fetch_target); + s = cache_->ToSubEntryRef(entry->get(), fetch_target); VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status " << s; @@ -71,18 +69,17 @@ Status TpuCompilationCacheLocalLookup::Lookup( Status TpuCompilationCacheLocalLookup::Lookup( int64 uid, int proto_index, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) { profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid", /*level=*/2); - Status s = cache_->Lookup( - uid, proto_index, entry); + Status s = cache_->Lookup(uid, proto_index, entry); VLOG(1) << "Looked up uid " << uid << ", index " << proto_index << " in local subgraph cache status " << s; if (!s.ok()) { return s; } - s = (*entry)->ToSubEntryRef(fetch_target); + s = cache_->ToSubEntryRef(entry->get(), fetch_target); VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status " << s; return s; diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h index 21ca74c46a8..0d068e1bdd1 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h @@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ -#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ +#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ +#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" namespace tensorflow { namespace tpu { @@ -30,11 +28,6 @@ namespace tpu { // and when they need to communicate over RPC. class TpuCompilationCacheLookup : public ResourceBase { public: - using TpuCompilationCacheEntryRef = - ::tensorflow::tpu::CompilationCacheEntryRef; - using EntryRefImpl = - ::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl; - ~TpuCompilationCacheLookup() override = default; // Looks up an executable corresponding to the model-parallel core index of @@ -49,11 +42,11 @@ class TpuCompilationCacheLookup : public ResourceBase { // fetch_target requests one of them, then after this call // (*entry)->get().get_executable() will return nullptr. virtual Status Lookup(const string& proto_key, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) = 0; virtual Status Lookup(const string& proto_key, - std::unique_ptr* entry) { + std::unique_ptr* entry) { return Lookup(proto_key, std::move(entry), CompilationCacheFetchTarget::MAIN); } @@ -63,30 +56,33 @@ class TpuCompilationCacheLookup : public ResourceBase { // returned in program. The wrapper is guaranteed to be valid only during the // execution of the Op requesting the proto. virtual Status Lookup(int64 uid, int proto_index, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) = 0; virtual Status Lookup(int64 uid, int proto_index, - std::unique_ptr* entry) { + std::unique_ptr* entry) { return Lookup(uid, proto_index, std::move(entry), CompilationCacheFetchTarget::MAIN); } }; +// Forward declaration to break cycle dependency graph. +class TpuCompilationCacheExternal; + // Class for looking up ISA protos when the execute and compile Op are in the // same address space. The proto is simply looked up in the compilation cache, // without any serialization taking place. class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup { public: - explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache); + explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheExternal* cache); ~TpuCompilationCacheLocalLookup() override; Status Lookup(const string& proto_key, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) override; Status Lookup(int64 uid, int proto_index, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) override; string DebugString() const override; @@ -94,10 +90,10 @@ class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup { private: // The subgraph compilation cache, in the same process address space where the // lookups are happening. - TpuCompilationCacheInterface* cache_; + TpuCompilationCacheExternal* cache_; }; } // namespace tpu } // namespace tensorflow -#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ +#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index a44c255df60..ae090913dc7 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/tpu_configuration.h" #include "tensorflow/core/tpu/tpu_defs.h" diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h index 36f9fa96db1..0f21e458828 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h"