From 899471f58fe71ca36e5f2669ea766acc5c47039d Mon Sep 17 00:00:00 2001 From: Henry Tan <henrytan@google.com> Date: Wed, 17 Jun 2020 22:06:11 -0700 Subject: [PATCH] TPU library internal update. PiperOrigin-RevId: 317033127 Change-Id: I06b440435c20e83126cf28c97f11ff1a3b18c61e --- tensorflow/core/tpu/kernels/BUILD | 53 +- .../core/tpu/kernels/compiled_subgraph.h | 10 +- .../kernels/tpu_compilation_cache_entry.cc | 2 +- .../tpu/kernels/tpu_compilation_cache_entry.h | 24 +- .../tpu_compilation_cache_entry_impl.h | 108 ++++ .../kernels/tpu_compilation_cache_external.cc | 564 +----------------- .../kernels/tpu_compilation_cache_external.h | 268 +-------- .../tpu_compilation_cache_interface.cc | 15 +- .../kernels/tpu_compilation_cache_interface.h | 355 +++++++++++ .../kernels/tpu_compilation_cache_lookup.cc | 17 +- .../kernels/tpu_compilation_cache_lookup.h | 32 +- .../core/tpu/kernels/tpu_compile_op_common.cc | 1 + .../core/tpu/kernels/tpu_compile_op_support.h | 1 - .../core/tpu/kernels/tpu_program_group.cc | 11 +- .../core/tpu/kernels/tpu_program_group.h | 9 +- .../tpu/kernels/tpu_program_group_interface.h | 6 +- 16 files changed, 609 insertions(+), 867 deletions(-) create mode 100644 tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h create mode 100644 tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 318d60b22df..9ba9ad61aa0 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -19,9 +19,9 @@ cc_library( deps = [ ":tpu_compile_op_support", ":tpu_mesh_state_interface", + ":tpu_program_group_interface", ":tpu_util", ":tpu_util_hdrs", - "@com_google_absl//absl/types:span", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -30,16 +30,16 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:compile_only_client", - "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - # "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", "//tensorflow/core/tpu:tpu_configuration", "//tensorflow/core/tpu:tpu_defs", "//tensorflow/stream_executor/tpu:tpu_platform_interface", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -157,14 +157,28 @@ cc_library( "tpu_compilation_cache_entry.h", ], deps = [ + ":compiled_subgraph", + ":tpu_compilation_cache_proto_cc", ":tpu_executable_info_proto_cc", ":tpu_program_group", "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/core:framework", "//tensorflow/core/lib/core:refcount", "//tensorflow/core/platform:casts", ], ) +cc_library( + name = "tpu_compilation_cache_entry_impl", + srcs = [], + hdrs = ["tpu_compilation_cache_entry_impl.h"], + deps = [ + ":compiled_subgraph", + ":tpu_compilation_cache_interface", + ":tpu_executable_info_proto_cc", + ], +) + cc_library( name = "tpu_compilation_cache_lookup", srcs = ["tpu_compilation_cache_lookup.cc"], @@ -174,6 +188,7 @@ cc_library( deps = [ ":tpu_compilation_cache_entry", ":tpu_compilation_cache_external", + ":tpu_compilation_cache_interface", ":tpu_compilation_cache_proto_cc", "//tensorflow/core/lib/core:refcount", 
"//tensorflow/core/platform:status", @@ -247,6 +262,35 @@ cc_library( ], ) +cc_library( + name = "tpu_compilation_cache_interface", + srcs = ["tpu_compilation_cache_interface.cc"], + hdrs = ["tpu_compilation_cache_interface.h"], + deps = [ + ":compiled_subgraph", + ":tpu_compilation_cache_key", + ":tpu_compilation_cache_metrics_hdrs", + ":tpu_compilation_cache_proto_cc", + ":tpu_util", + ":tpu_util_hdrs", + ":trace_util_hdrs", + "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/distributed_runtime/rpc:grpc_call", + "//tensorflow/core/platform:casts", # buildcleaner: keep + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) + cc_library( name = "tpu_compilation_cache_external", srcs = ["tpu_compilation_cache_external.cc"], @@ -256,6 +300,8 @@ cc_library( deps = [ ":compiled_subgraph", ":tpu_compilation_cache_entry", + ":tpu_compilation_cache_entry_impl", + ":tpu_compilation_cache_interface", ":tpu_compilation_cache_key", ":tpu_compilation_cache_metrics", # buildcleaner: keep ":tpu_compilation_cache_metrics_hdrs", @@ -355,6 +401,7 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/core/tpu/kernels/compiled_subgraph.h b/tensorflow/core/tpu/kernels/compiled_subgraph.h index 1066e4839dd..a97c652c279 100644 --- a/tensorflow/core/tpu/kernels/compiled_subgraph.h +++ b/tensorflow/core/tpu/kernels/compiled_subgraph.h @@ -25,6 +25,9 @@ limitations under the License. namespace tensorflow { namespace tpu { +// Forward declaration to avoid circular dependency. +class TpuCompilationCacheInterface; + // Cache for compiled TPU program. // // Each key identifies a unique subgraph, and the value is the vector of @@ -100,10 +103,7 @@ namespace tpu { // unmarked and set to most recently used. // struct CompiledSubgraph : public core::RefCounted { - // TODO(henrytan): once `TpuCompilationCache` and - // `TpuCompilationCacheExternal` inherits from `TpuCompilationCacheInterface` - // update void* with `TpuCompilationCacheInterface` - void* parent = nullptr; // Not owned. + TpuCompilationCacheInterface* parent = nullptr; // Not owned. bool initialized = false; @@ -145,7 +145,7 @@ struct CompiledSubgraph : public core::RefCounted { // owning main entry. CompiledSubgraph* main_entry = nullptr; - // Compiled Tpu program. + // Compiled TPU program group. std::unique_ptr<TpuProgramGroupInterface> tpu_program_group; // Computes total program size. 
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc index 4d1f306ec0c..73f55853306 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc @@ -40,7 +40,7 @@ TpuCompilationCacheEntry::get_host_transfer_info() const { } const xla::HloProto* TpuCompilationCacheEntry::get_hlo_metadata() const { - return tpu_program_group_->hlo_metadatas()[core_index_].get(); + return tpu_program_group_->hlo_metadatas()[core_index_]; } // TODO(henrytan,jiawenhao): When should we expect more than one diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h index a561fc51778..b3766b8b4dd 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h @@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { namespace tpu { -// A version of `CompilationCacheEntry` that exposes Tpu binary program +// A version of `CompilationCacheEntry` to access Tpu binary program // `XLA_TpuProgram`. class TpuCompilationCacheEntry { public: @@ -42,28 +42,6 @@ class TpuCompilationCacheEntry { int core_index_; }; -// Base class for a reference to a cached proto. A unique_ptr to a -// CompilationCacheEntryRef is returned by all the cache Lookup methods below, -// and ensures the underlying proto is not garbage-collected until the client -// discards the ptr. -class CompilationCacheEntryRef { - public: - virtual ~CompilationCacheEntryRef() = default; - - // Returns a CompilationCacheEntry that should not be used beyond the lifetime - // of the CompilationCacheEntryRef. - virtual TpuCompilationCacheEntry get() = 0; -}; - -// Base class that holds references to compiled protos so that the protos are -// not garbage-collected before being used by execute ops. Use -// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete -// ref holder object. -class CompilationRefHolder : public ResourceBase { - public: - ~CompilationRefHolder() override = default; -}; - } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h new file mode 100644 index 00000000000..501f802b01f --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h @@ -0,0 +1,108 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ + +#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" + +namespace tensorflow { +namespace tpu { + +// Wrapper for a cache entry that holds a reference to the entry until the +// wrapper is deleted. This wrapper is the concrete type of +// CompilationCacheEntryRef returned by Lookup. +template <typename CacheEntryType> +class CompilationCacheEntryRefImpl + : public CompilationCacheEntryRef<CacheEntryType> { + public: + CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent, + CompiledSubgraph* entry, int index); + + ~CompilationCacheEntryRefImpl() override; + + Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override; + + protected: + TpuCompilationCacheInterface* parent_; // Not owned. + // A reference to entry_ is acquired in the constructor and released via + // parent->DiscardEntryRefs in the destructor. + CompiledSubgraph* entry_; + // The index of the program in entry_ that is returned by the get method. + int index_; +}; + +template <typename CacheEntryType> +CompilationCacheEntryRefImpl<CacheEntryType>::CompilationCacheEntryRefImpl( + TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index) + : parent_(parent), entry_(entry), index_(index) { + if (entry_ == nullptr) { + return; + } + if (entry_->main_entry == nullptr) { + entry_->Ref(); + } else { + // This is a sharding/unsharding entry nested in a main entry. Only + // refcount the main entry. + entry_->main_entry->Ref(); + } +} + +template <typename CacheEntryType> +CompilationCacheEntryRefImpl<CacheEntryType>::~CompilationCacheEntryRefImpl() { + if (entry_ == nullptr) { + return; + } + if (entry_->main_entry == nullptr) { + parent_->DiscardEntryRefs({entry_}); + } else { + parent_->DiscardEntryRefs({entry_->main_entry}); + } +} + +template <typename CacheEntryType> +Status CompilationCacheEntryRefImpl<CacheEntryType>::ToSubEntryRef( + CompilationCacheFetchTarget fetch_target) { + CompiledSubgraph* target = nullptr; + switch (fetch_target) { + case CompilationCacheFetchTarget::MAIN: + target = entry_; + break; + case CompilationCacheFetchTarget::SHARDING: + target = entry_->sharding_entry.get(); + break; + case CompilationCacheFetchTarget::UNSHARDING: + target = entry_->unsharding_entry.get(); + break; + default: + return xla::InvalidArgument("Invalid fetch target: %d", fetch_target); + } + + if (target == nullptr) { + // Cache entry does not have an unsharding subentry. Unref and replace + // with nullptr. + parent_->DiscardEntryRefs({entry_}); + } + // Otherwise, since the refcount is always on the main entry, we don't + // need ref/unref. 
+ entry_ = target; + return Status::OK(); +} + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc index 614dfbdf577..8cee90e8e55 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc @@ -50,14 +50,6 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry, entry->initialized = true; } -std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) { - if (!key.has_guaranteed_const) { - return key.prefix; - } - return absl::StrCat(key.prefix, "|", key.session_handle, "|", - key.guaranteed_const_fingerprint()); -} - // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor // data to compute the fingerprint. std::string GuaranteedConstFingerprint( @@ -123,85 +115,32 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) { } // namespace -TpuCompilationCacheExternal::TpuCompilationCacheExternal(int64_t max_cache_size) - : max_cache_size_(max_cache_size) { - if (max_cache_size < 0) { - LOG(FATAL) << "`max_cache_size` value must be greater than equal to 0"; - } - VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes."; -} +TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl( + TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index) + : CompilationCacheEntryRefImpl<TpuCompilationCacheEntry>(parent, entry, + index) {} -TpuCompilationCacheExternal::~TpuCompilationCacheExternal() { - VLOG(1) << "TpuCompilationCacheExternal::~TpuCompilationCacheExternal()"; - // A buggy client may be holding onto a reference, or a client might have - // crashed while holding onto a reference. In either case, discard all - // outstanding client references to avoid leaking storage. - for (const auto& entry : entries_by_uid_) { - while (entry.second->external_references > 0) { - TF_CHECK_OK(Release(entry.first)); - } +TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() { + if (entry_ == nullptr) { + // Create an empty entry if the entry is nullptr. This corresponds to + // non-existing sharding/unsharding entries. + return TpuCompilationCacheEntry(); } - while (!entries_by_last_use_.empty()) { - UnloadAndDestroy(MarkOldestEntryForEviction()); - } - // By the time the cache is deleted all reference holders should have already - // been deleted, since they were holding references to the cache. So all - // entries should be gone at this point. 
- CHECK_EQ(cache_store_.size(), 0); - CHECK_EQ(entries_by_uid_.size(), 0); - CHECK_EQ(entries_by_proto_key_.size(), 0); - CHECK_EQ(cache_size_, 0); - CHECK_EQ(marked_for_eviction_size_, 0); -} - -std::string TpuCompilationCacheExternal::FindCacheKey( - const TpuCompilationCacheKey& subgraph_key) const { - if (!subgraph_key.has_guaranteed_const) { - return subgraph_key.prefix; - } - auto iter = session_key_map_.find( - strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle)); - if (iter != session_key_map_.end()) { - return iter->second; - } - iter = fingerprint_key_map_.find(strings::StrCat( - subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint())); - if (iter != session_key_map_.end()) { - return iter->second; - } - VLOG(1) << "No matching cache key found for key " - << ConstructCompilationCacheKey(subgraph_key); - return ""; -} - -void TpuCompilationCacheExternal::InsertEntry( - const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key, - CompiledSubgraph* entry) { - entry->parent = this; - entry->subgraph_key = cache_key; - entry->uid = get_uid(); - TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size()); - entry->cache_entry_debug_string = subgraph_key.prefix; - VLOG(1) << "Cache Initializing Entry Session Debug " - << entry->cache_entry_debug_string; - - if (!subgraph_key.has_guaranteed_const) { - return; - } - session_key_map_.insert(std::make_pair( - strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle), - cache_key)); - fingerprint_key_map_.insert(std::make_pair( - strings::StrCat(subgraph_key.prefix, - subgraph_key.guaranteed_const_fingerprint()), - cache_key)); + return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_); } CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry( const string& key, - const std::function<Status(TpuProgramGroup*)>& initialize_program, + const std::function<Status(TpuProgramGroupInterface*)>& initialize_program, const TpuCompilationCacheKey& subgraph_key) { CompiledSubgraph* main_entry = new CompiledSubgraph(); + main_entry->parent = this; + main_entry->subgraph_key = key; + main_entry->uid = get_uid(); + // TODO(henrytan): implement TpuCompilationCacheKey.debug_string. + main_entry->cache_entry_debug_string = subgraph_key.prefix; + VLOG(1) << "Cache Initializing Entry Session Debug " + << main_entry->cache_entry_debug_string; // Add the entry to the cache, with size zero since there are no compiled // programs in it. Once the subgraph has been compiled, @@ -212,7 +151,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry( // who created the entry. A second reference, owned by the cache, will be // added below since we leave the entry in the 'marked for eviction' state // here. - InsertEntry(key, subgraph_key, main_entry); + InsertEntry(key, main_entry); // Initialize the programs outside the lock so that other cache operations // can proceed during the (potentially lengthy) initialization. 
@@ -320,470 +259,5 @@ TpuCompilationCacheExternal::CreateCompilationCacheKey( } return key; } - -TpuCompilationRefHolder* TpuCompilationCacheExternal::MakePerStepRefHolder() { - return new RefHolder(this); -} - -Status TpuCompilationCacheExternal::MarkEntryForEviction(int64 subgraph_uid) { - profiler::TraceMe key_release_traceme( - "TPU compilation cache possibly evict uid", - /*level=*/2); - CompiledSubgraph* deleted_entry = nullptr; - { - absl::MutexLock lock(&mu_); - auto iter = entries_by_uid_.find(subgraph_uid); - if (iter == entries_by_uid_.end()) { - // If already evicted, return ok. - return Status::OK(); - } - - // Mark entry for eviction. - CompiledSubgraph* subgraph_to_evict = iter->second; - // If there are external references, should not use this API. - if (subgraph_to_evict->external_references != 0) { - return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key, - " external_references greater than zero. Should " - "use TpuCompilationCache::Release."); - } - - VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction"; - entries_by_last_use_.erase(subgraph_to_evict->last_use); - cache_size_ -= subgraph_to_evict->total_size; - marked_for_eviction_size_ += subgraph_to_evict->total_size; - - // Evict if refcount exactly one, otherwise only discard cache's reference - // to the entry while the actual eviction will happen when refholder's - // references go away. - deleted_entry = DiscardEntryRef(subgraph_to_evict); - - VLOG(1) << "After possibly evicting entry " << subgraph_uid - << " refs cache is " << cache_store_.size() << " entries (" - << cache_size_ + marked_for_eviction_size_ - << " bytes), marked for eviction " - << (cache_store_.size() - entries_by_last_use_.size()) - << " entries (" << marked_for_eviction_size_ << " bytes)."; - } - - // Unload from device cache if entry is evicted from host cache. 
- UnloadAndDestroy(deleted_entry); - return Status::OK(); -} - -Status TpuCompilationCacheExternal::Release(int64 subgraph_uid) { - profiler::TraceMe key_release_traceme("TPU compilation cache release uid", - /*level=*/2); - - CompiledSubgraph* deleted_entry = nullptr; - { - absl::MutexLock lock(&mu_); - auto iter = entries_by_uid_.find(subgraph_uid); - - if (iter == entries_by_uid_.end()) { - return errors::NotFound("No cache entry found for uid ", subgraph_uid); - } - - CHECK_GT(iter->second->external_references, 0); - --iter->second->external_references; - - deleted_entry = DiscardEntryRef(iter->second); - - VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is " - << cache_store_.size() << " entries (" - << cache_size_ + marked_for_eviction_size_ - << " bytes), marked for eviction " - << (cache_store_.size() - entries_by_last_use_.size()) - << " entries (" << marked_for_eviction_size_ << " bytes)."; - } - UnloadAndDestroy(deleted_entry); - return Status::OK(); -} - -void TpuCompilationCacheExternal::UnloadAndDestroy(CompiledSubgraph* entry) { - if (!entry) return; - - CHECK(entry->RefCountIsOne()); - entry->tpu_program_group->UnloadAndDestroyPrograms(); - entry->Unref(); -} - -size_t TpuCompilationCacheExternal::RemoveEntry(const string& key) { - auto erased = cache_store_.erase(key); - TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size()); - auto parsed_key_or_status = ParseCompilationCacheKey(key); - CHECK(parsed_key_or_status.status().ok()); - const TpuCompilationCacheKey parsed_key = - parsed_key_or_status.ConsumeValueOrDie(); - if (!parsed_key.has_guaranteed_const) { - return erased; - } - session_key_map_.erase( - strings::StrCat(parsed_key.prefix, parsed_key.session_handle)); - fingerprint_key_map_.erase(strings::StrCat( - parsed_key.prefix, parsed_key.guaranteed_const_fingerprint())); - return erased; -} - -ABSL_MUST_USE_RESULT CompiledSubgraph* -TpuCompilationCacheExternal::DiscardEntryRef(CompiledSubgraph* entry) { - if (entry->RefCountIsOne()) { - // The last reference to this entry is going away, so really delete it from - // the cache in such a way that it can't be restored by being looked up - // again. - - // Sanity-check that it has been marked for eviction. - CHECK(entries_by_last_use_.find(entry->last_use) == - entries_by_last_use_.end()); - // Update the counter tracking how much space is taken up by entries that - // are marked for eviction. - marked_for_eviction_size_ -= entry->total_size; - - // Remove the entry from the cache. - auto erased = RemoveEntry(entry->subgraph_key); - - if (erased == 0) { - LOG(FATAL) << "Tried to discard nonexistent cache entry"; - } - erased = entries_by_uid_.erase(entry->uid); - CHECK_EQ(erased, 1); - for (const string& key : entry->proto_key) { - erased = entries_by_proto_key_.erase(key); - CHECK_EQ(erased, 1); - } - // The actual deletion will happen outside the lock in UnloadAndDestroy(). 
- return entry; - } - entry->Unref(); - return nullptr; -} - -void TpuCompilationCacheExternal::DiscardEntryRefs( - gtl::ArraySlice<CompiledSubgraph*> entries) { - std::vector<CompiledSubgraph*> removed_entries; - { - absl::MutexLock lock(&mu_); - - for (auto entry : entries) { - removed_entries.push_back(DiscardEntryRef(entry)); - } - - VLOG(1) << "After discarding entry refs cache is " << cache_store_.size() - << " entries (" << cache_size_ + marked_for_eviction_size_ - << " bytes), marked for eviction " - << (cache_store_.size() - entries_by_last_use_.size()) - << " entries (" << marked_for_eviction_size_ << " bytes)."; - } - for (auto removed_entry : removed_entries) { - UnloadAndDestroy(removed_entry); - } -} - -ABSL_MUST_USE_RESULT CompiledSubgraph* -TpuCompilationCacheExternal::MarkOldestEntryForEviction() { - CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second; - VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction"; - entries_by_last_use_.erase(entry_to_mark->last_use); - cache_size_ -= entry_to_mark->total_size; - marked_for_eviction_size_ += entry_to_mark->total_size; - // Discard the cache's reference to entry. If steps are holding onto - // references to entry it won't be deleted until the last step holding it - // completes. It stays in the cache in the meantime and can be resurrected - // by a call to CompileIfKeyAbsent if that occurs before the last reference - // expires. - return DiscardEntryRef(entry_to_mark); -} - -void TpuCompilationCacheExternal::LookupEntryMarkedForEviction( - CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) { - // The entry was previously marked for eviction (or is newly created) so - // unmark it. Add a reference (owned by the cache), update the cache size, and - // mark something old for eviction if necessary. - entry->Ref(); - marked_for_eviction_size_ -= entry->total_size; - cache_size_ += entry->total_size; - - // Mark the least-recently-used non-marked entry for eviction. Never mark the - // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1 - // which means there's only one entry not already marked for eviction), so - // that an entry persists in the cache even if it is larger than the allocated - // cache size. - while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) { - if (auto entry_to_evict = MarkOldestEntryForEviction()) { - removed_entries->push_back(entry_to_evict); - } - } -} - -Status TpuCompilationCacheExternal::ToSubEntryRef( - CompilationCacheEntryRef* entry, - CompilationCacheFetchTarget fetch_target) const { - return static_cast<TpuEntryRefImpl*>(entry)->ToSubEntryRef(fetch_target); -} - -TpuCompilationCacheExternal::TpuEntryRefImpl::TpuEntryRefImpl( - TpuCompilationCacheExternal* parent, CompiledSubgraph* entry, int index) - : parent_(parent), entry_(entry), index_(index) { - if (entry_ == nullptr) { - return; - } - if (entry_->main_entry == nullptr) { - entry_->Ref(); - } else { - // This is a sharding/unsharding entry nested in a main entry. Only refcount - // the main entry. - entry_->main_entry->Ref(); - } -} - -TpuCompilationCacheExternal::TpuEntryRefImpl::~TpuEntryRefImpl() { - if (entry_ == nullptr) { - return; - } - if (entry_->main_entry == nullptr) { - parent_->DiscardEntryRefs({entry_}); - } else { - parent_->DiscardEntryRefs({entry_->main_entry}); - } -} - -TpuCompilationCacheEntry TpuCompilationCacheExternal::TpuEntryRefImpl::get() { - if (entry_ == nullptr) { - // Create an empty entry if the entry is nullptr. 
This corresponds to - // non-existing sharding/unsharding entries. - return TpuCompilationCacheEntry(); - } - return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_); -} - -Status TpuCompilationCacheExternal::TpuEntryRefImpl::ToSubEntryRef( - CompilationCacheFetchTarget fetch_target) { - CompiledSubgraph* target = nullptr; - switch (fetch_target) { - case CompilationCacheFetchTarget::MAIN: - target = entry_; - break; - case CompilationCacheFetchTarget::SHARDING: - target = entry_->sharding_entry.get(); - break; - case CompilationCacheFetchTarget::UNSHARDING: - target = entry_->unsharding_entry.get(); - break; - default: - return xla::InvalidArgument("Invalid fetch target: %d", fetch_target); - } - - if (target == nullptr) { - // Cache entry does not have an unsharding subentry. Unref and replace - // with nullptr. - parent_->DiscardEntryRefs({entry_}); - } - // Otherwise, since the refcount is always on the main entry, we don't need - // ref/unref. - entry_ = target; - return Status::OK(); -} - -Status TpuCompilationCacheExternal::Lookup( - int64 uid, int proto_index, - std::unique_ptr<CompilationCacheEntryRef>* entry) { - entry->reset(); - - profiler::TraceMe proto_lookup_traceme( - "TPU compilation cache proto lookup by uid", - /*level=*/2); - - absl::MutexLock lock(&mu_); - const auto iter = entries_by_uid_.find(uid); - if (iter == entries_by_uid_.end()) { - return errors::NotFound("No subgraph found for uid ", uid); - } - CompiledSubgraph* cache_entry = iter->second; - if (proto_index < 0 || - proto_index >= cache_entry->tpu_program_group->program_size()) { - return errors::NotFound("No proto found for core index ", proto_index, - " in subgraph with uid ", uid); - } - *entry = std::unique_ptr<CompilationCacheEntryRef>( - new TpuEntryRefImpl(this, cache_entry, proto_index)); - return Status::OK(); -} - -Status TpuCompilationCacheExternal::Lookup( - const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry) { - entry->reset(); - - profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup", - /*level=*/2); - - absl::MutexLock lock(&mu_); - const auto iter = entries_by_proto_key_.find(proto_key); - if (iter == entries_by_proto_key_.end()) { - return errors::NotFound("No proto found for key ", proto_key); - } - CompiledSubgraph* cache_entry = iter->second.first; - int proto_index = iter->second.second; - *entry = std::unique_ptr<CompilationCacheEntryRef>( - new TpuEntryRefImpl(this, cache_entry, proto_index)); - return Status::OK(); -} - -Status TpuCompilationCacheExternal::CompileIfKeyAbsentHelper( - const TpuCompilationCacheKey& subgraph_key, - const SessionMetadata* session_metadata, - TpuCompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector<string>* proto_key, std::vector<bool>* may_modify_variables, - std::vector<CompiledSubgraph*>* removed_entries, - std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata, - const std::function<Status(TpuProgramGroup*)>& compile_function) { - profiler::TraceMe subgraph_lookup_traceme( - "TPU compilation cache subgraph lookup", - /*level=*/2); - - // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock - // for the lifetime of the object, see InitializeEntry() call below. 
- absl::MutexLock lock(&mu_); - - std::string cache_key = FindCacheKey(subgraph_key); - auto iter = cache_store_.find(cache_key); - bool is_new_key = iter == cache_store_.end(); - - const std::string session_name = SessionNameFromMetadata(session_metadata); - - CompiledSubgraph* entry = nullptr; - if (is_new_key) { - cache_key = ConstructCompilationCacheKey(subgraph_key); - TpuCompilationCacheMetrics::IncrementCacheLookupCount( - /*is_cache_hit=*/false, session_name); - const string msg = - strings::StrCat("TPU host compilation cache miss: cache_key(", - cache_key, "), session_name(", session_name, ")"); - - TRACESTRING(msg); - LOG(INFO) << msg; - - // Check if caller has disabled compilation. Set using - // internal::ScopedTpuCompileDisabler. - if (!IsTpuCompilationEnabled()) { - const string error_msg = strings::StrCat( - "[TpuCompilationDisabled]: Compilation cache miss, but compilation " - "disabled, session_name(", - session_name, ") Debug String: ", subgraph_key.debug_string); - if (VLOG_IS_ON(2)) { - VLOG(2) << "Cache Missed. Current cache entries: "; - for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) { - // TODO(henrytan): add DebugKey as cache_entry_debug_string to - // TpuCompilationCacheKey. - VLOG(2) << "Cache Debug Info: "; - VLOG(2) << it->second->cache_entry_debug_string; - } - } - - LOG_EVERY_N_SEC(WARNING, 30) << error_msg; - return errors::NotFound(error_msg); - } - - // The single ref on the newly-created entry is owned by the caller. - VLOG(1) << "Before adding new entry for key " << cache_key - << " with session_name( " << session_name << ");" - << "; cache is " << cache_store_.size() << " entries (" - << cache_size_ + marked_for_eviction_size_ << " bytes), " - << " marked for eviction " - << (cache_store_.size() - entries_by_last_use_.size()) - << " entries (" << marked_for_eviction_size_ << " bytes)."; - // Note that InitializeEntry() will Release/Reacquire mu_. - entry = InitializeEntry(cache_key, compile_function, subgraph_key); - TRACELITERAL("TPU host compilation cache: compilation done."); - - LOG(INFO) << strings::StrCat( - "TPU host compilation cache: compilation done for cache_key(", - cache_key, "), session_name(", session_name, ")"); - // If session_name is present, log some additional stats related to HBM - // here, so that they can be associated directly to the session. - if (!session_name.empty()) { - entry->tpu_program_group->LogProgramMemorySummary(); - } - } else { - TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name); - const string msg = - strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key, - "), session_name(", session_name, ")"); - TRACESTRING(msg); - VLOG(1) << msg; - VLOG(1) << "Before refreshing entry for key " << cache_key - << " with session_name( " << session_name << "); cache is " - << cache_store_.size() << " entries (" - << cache_size_ + marked_for_eviction_size_ << " bytes), " - << " marked for eviction " - << (cache_store_.size() - entries_by_last_use_.size()) - << " entries (" << marked_for_eviction_size_ << " bytes)."; - entry = iter->second; - // Make a new reference that is owned by the caller. - entry->Ref(); - // Block if necessary until the subgraph has been initialized. - mu_.Await(absl::Condition( - +[](CompiledSubgraph* e) { return e->initialized; }, entry)); - } - - // Let the caller know the uid of the entry. - *uid = entry->uid; - // Let the caller know the keys for each of the cached protos. 
- *proto_key = entry->proto_key; - *may_modify_variables = entry->tpu_program_group->may_modify_variables(); - *hlo_metadata = entry->tpu_program_group->hlo_metadatas(); - - // If the caller didn't supply a per_step_ref_holder then the caller is going - // to manually release the reference later via a call to Release(). - if (per_step_ref_holder == nullptr) { - ++entry->external_references; - } else { - // The caller wants its reference to be handed off to a per-step holder that - // will discard the reference when the step completes. - RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder); - TF_RET_CHECK(cast_ref_holder != nullptr); - cast_ref_holder->AddRef(entry); - } - - // Remove the old LRU-table entry if it wasn't already marked for eviction. - auto erased = entries_by_last_use_.erase(entry->last_use); - // Update the LRU table indicating this entry is the most recently used. - entry->last_use = use_counter_++; - entries_by_last_use_[entry->last_use] = entry; - if (erased == 0) { - // The entry had been marked for eviction, or is newly created. - LookupEntryMarkedForEviction(entry, removed_entries); - } - - // Log a little more verbosely when a key is added. - if (VLOG_IS_ON(1) || is_new_key) { - LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing") - << " entry for key " << cache_key << " with session_name " - << session_name << " cache is " << cache_store_.size() - << " entries (" << cache_size_ + marked_for_eviction_size_ - << " bytes), " - << " marked for eviction " - << (cache_store_.size() - entries_by_last_use_.size()) - << " entries (" << marked_for_eviction_size_ << " bytes)."; - } - return entry->initialization_status; -} - -tensorflow::Status TpuCompilationCacheExternal::CompileIfKeyAbsent( - const TpuCompilationCacheKey& cache_key, - const tensorflow::SessionMetadata* session_metadata, - TpuCompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector<string>* proto_key, std::vector<bool>* may_modify_variables, - std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata, - const std::function<tensorflow::Status(TpuProgramGroup*)>& - compile_function) { - std::vector<CompiledSubgraph*> removed_entries; - auto status = CompileIfKeyAbsentHelper( - cache_key, session_metadata, per_step_ref_holder, uid, proto_key, - may_modify_variables, &removed_entries, hlo_metadata, compile_function); - for (auto entry : removed_entries) { - UnloadAndDestroy(entry); - } - return status; -} - } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h index eff2afde108..2c75cb4d053 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h @@ -26,11 +26,14 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/compiled_subgraph.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" @@ -40,37 +43,25 @@ limitations under the License. namespace tensorflow { namespace tpu { -const char kCompilationCacheResourceName[] = "tpu_compilation_cache"; -const char kCompilationCacheUnloaderResourceName[] = +constexpr char kCompilationCacheResourceName[] = "tpu_compilation_cache"; +constexpr char kCompilationCacheUnloaderResourceName[] = "tpu_compilation_cache_unloader"; -// Base class that holds references to compiled protos so that the protos are -// not garbage-collected before being used by execute ops. Use -// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete -// ref holder object. -class TpuCompilationRefHolder : public ResourceBase { - public: - ~TpuCompilationRefHolder() override = default; -}; - -class TpuCompilationCacheExternal : public ResourceBase { +class TpuCompilationCacheExternal : public TpuCompilationCacheInterface { public: using Status = ::stream_executor::port::Status; - explicit TpuCompilationCacheExternal(int64_t max_cache_size); - ~TpuCompilationCacheExternal() override; - TpuCompilationCacheExternal(const TpuCompilationCacheExternal&) = delete; - TpuCompilationCacheExternal& operator=(const TpuCompilationCacheExternal&) = - delete; + class EntryRefImpl + : public CompilationCacheEntryRefImpl<TpuCompilationCacheEntry> { + public: + EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, + int index); - Status CompileIfKeyAbsent( - const TpuCompilationCacheKey& cache_key, - const SessionMetadata* session_metadata, - TpuCompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector<string>* proto_key, std::vector<bool>* may_modify_variables, - std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata, - const std::function<tensorflow::Status(TpuProgramGroup*)>& - compile_function); + TpuCompilationCacheEntry get() override; + }; + + explicit TpuCompilationCacheExternal(int64 max_cache_size) + : TpuCompilationCacheInterface(max_cache_size) {} static TpuCompilationCacheKey CreateCompilationCacheKey( absl::string_view function_name, uint64 function_library_fingerprint, @@ -82,177 +73,7 @@ class TpuCompilationCacheExternal : public ResourceBase { string DebugString() const override { return "TpuCompilationCacheExternal"; } - // Makes a reference holder for this cache, that can be stored in the per-step - // resource manager and will ensure that compiled entries persist until the - // end of a step. - TpuCompilationRefHolder* MakePerStepRefHolder(); - - // Differences between MarkEntryForEviction and Release: - // There are two modes of managing cache entries: - // 1) LRU eviction + pinning; 2) manual. - // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent. 
- // Otherwise it is manual mode (mainly used by XRT). - // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache - // entries when callers know that they do not need them anymore. - // Release should only be used in mode 2) to explicitly remove an entry. - - // Mark the entry indexed by `subgraph_uid` for eviction. This should only be - // called if per_step_ref_holder was NOT nullptr in the corresponding call to - // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64 - // subgraph_uid). - Status MarkEntryForEviction(int64 subgraph_uid); - - // Manually discards a reference to the compiled subgraph. This should only be - // called if per_step_ref_holder was nullptr in the corresponding call to - // CompileIfKeyAbsent(subgraph_key, ...). - Status Release(int64 subgraph_uid); - - // Looks up an executable corresponding to the model-parallel core index of - // the subgraph represented by key. On success a pointer to an EntryRef - // holding the program is returned in entry. - Status Lookup(const string& proto_key, - std::unique_ptr<CompilationCacheEntryRef>* entry); - - // Looks up an executable corresponding to the model-parallel core index of - // the subgraph represented by uid. On success a pointer to an EntryRef - // holding the program is returned in entry. - Status Lookup(int64 uid, int proto_index, - std::unique_ptr<CompilationCacheEntryRef>* entry); - - // Mutates the main entry ref to point to the entry's subentry - // (for sharding/unsharding) or main entry (unchanged) representing the - // fetch target. The entry ref needs to point to the main entry before this - // call. - // - // If the requested subentry does not exist, the ref will point to a nullptr - // entry. - Status ToSubEntryRef(CompilationCacheEntryRef* entry, - CompilationCacheFetchTarget fetch_target) const; - private: - // Wrapper for a cache entry that holds a reference to the entry until the - // wrapper is deleted. This wrapper is the concrete type of - // CompilationCacheEntryRef returned by Lookup. - class TpuEntryRefImpl : public CompilationCacheEntryRef { - public: - TpuEntryRefImpl(TpuCompilationCacheExternal* parent, - CompiledSubgraph* entry, int index); - ~TpuEntryRefImpl() override; - - TpuCompilationCacheEntry get() override; - - // Mutates this ref to point to the entry's subentry (for - // sharding/unsharding) or main entry (unchanged) as specified by - // fetch_target. The refcount is kept unchanged, since we only track the - // refcount of the main entry. The entry ref needs to point to the main - // entry before this call. - // - // If the requested subentry does not exist, the ref will point to a nullptr - // entry, and the original entry will be unref'ed. - Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target); - - private: - TpuCompilationCacheExternal* parent_; // Not owned. - // A reference to entry_ is acquired in the constructor and released via - // parent->DiscardEntryRefs in the destructor. - CompiledSubgraph* entry_; - // The program in entry_ that is returned by the get method. - int index_; - }; - - // Private implementation of the generic CompilationRefHolder that knows about - // CompiledSubgraph entries. - class RefHolder : public TpuCompilationRefHolder { - public: - explicit RefHolder(TpuCompilationCacheExternal* parent) : parent_(parent) { - parent_->Ref(); - } - ~RefHolder() override { - // Release our reference to the parent. 
- parent_->Unref(); - } - - // Adds entry to the list of entries that will be released when the - // RefHolder is destroyed. Each entry is released via a call to - // parent_->DiscardEntryRefs. - void AddRef(CompiledSubgraph* entry) { entries_.push_back(entry); } - - string DebugString() const override { - return "TpuCompilationCacheExternal::RefHolder"; - } - - private: - TpuCompilationCacheExternal* parent_; // Not owned. - std::vector<CompiledSubgraph*> entries_; - }; - - // The bulk of implementation of CompileIfKeyAbsent() with the exception - // of unloading programs that corresponds to possibly removed cache - // entries. The split helps to manage locking since we prefer to perform - // unloading without holding extra locks. - Status CompileIfKeyAbsentHelper( - const TpuCompilationCacheKey& subgraph_key, - const SessionMetadata* session_metadata, - TpuCompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector<string>* proto_key, std::vector<bool>* may_modify_variables, - std::vector<CompiledSubgraph*>* removed_entries, - std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata, - const std::function<Status(TpuProgramGroup*)>& compile_function); - - // This is called by the cache when entry is marked for eviction; by - // a RefHolder (via DiscardEntryRefs) when a step completes; and by - // an EntryRefImpl when it is destroyed. Releases one reference to entry - // if more than 1 remains. If only one reference is left, the entry is removed - // from cache_ and is returned to the caller; which must eventually call - // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef - // to avoid holding the lock during program unloading. - ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef( - CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Convenience method called by ~RefHolder without mu_ held. Calls - // DiscardEntryRef on every element of entries. - void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries); - - // Marks the oldest unmarked entry for eviction. Requires that there is at - // least one such entry. In case the evicted entry had only 1 reference it - // is removed from the cache and returned to the caller which must eventually - // call UnloadAndDestroy. - CompiledSubgraph* MarkOldestEntryForEviction() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Updates datastructures to indicate that entry, which had been marked for - // eviction, has been looked up. This is called by CompileIfKeyAbsent when an - // entry is newly created, or an entry that has been marked for eviction but - // not yet evicted is looked up. - // - // First the entry is unmarked for eviction, i.e. the cache gains a reference - // to entry, entry's last_use field is set to be the most recent value of - // use_counter_ and entries_by_last_use_ is updated accordingly. - // - // Next, the size of the cache is examined to see if any other entries need to - // be marked for eviction now that entry has been unmarked. While the total - // size of unmarked cached entries is greater than max_cache_size_, entries - // are marked for eviction in LRU order. The most recently used entry is never - // marked for eviction, so an entry larger than the max cache size will remain - // in the cache until it is replaced by something else. In case some entries - // actually were removed from the cache, they are a returned to the caller via - // removed_entries. The caller must eventually delete them by calling - // UnloadAndDestroy. 
- void LookupEntryMarkedForEviction( - CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Removes the entry with given key from cache. - size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Inserts the given key and entry to cache. - void InsertEntry(const std::string& key, - const TpuCompilationCacheKey& subgraph_key, - CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Returns the cache key matching given subgraph_key. - std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Creates a new entry by running initialize_programs and places it in the // cache to be looked up by key. The new entry is in the 'marked for eviction' // state (not present in entries_by_last_use_) and the caller is expected to @@ -261,61 +82,10 @@ class TpuCompilationCacheExternal : public ResourceBase { // **InitializeEntry releases mu_ during the call to initialize_programs.** CompiledSubgraph* InitializeEntry( const string& key, - const std::function<Status(TpuProgramGroup*)>& initialize_program, + const std::function<Status(TpuProgramGroupInterface*)>& + initialize_program, const TpuCompilationCacheKey& subgraph_key) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Unloads the program associated with the entry from all local devices - // and deletes the entry itself. It is assumed no one else has a reference - // to it and all related keys had already been removed from the cache. - // The call can perform device IO so no locks should be held while calling it. - void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_); - - // The maximum size of entries that are stored in the cache before entries are - // marked for eviction. - const int64 max_cache_size_; - - mutable absl::Mutex mu_; - // The total size of entries that are stored and not marked for eviction. - int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0; - - // The total size of entries that are marked for eviction. - int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0; - - // The value to assign to the last_use field of the next entry that is looked - // up. - int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0; - - // session_key_map_ and fingerprint_key_map_ are used for looking up the - // cache_ key matching a given subgraph key. When doing a lookup, check - // session_key_map_ first to avoid unnecessay fingerprint computation. - // Map from key prefix + session_handle to a cache_ key. - std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_); - - // Map from key prefix + fingerprint to a cache_ key. - std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_); - - // All the subgraph entries that can be looked up in the cache. An entry is - // marked for eviction iff it is present in cache_ and not in - // entries_by_last_use_. - std::unordered_map<string, CompiledSubgraph*> cache_store_ - ABSL_GUARDED_BY(mu_); - - // All the subgraph entries that can be looked up in the cache, indexed by - // uid. - absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_ - ABSL_GUARDED_BY(mu_); - - // All the protos that can be looked up in the cache, indexed by proto - // key. The value of the map is a subgraph and the index of the proto compiled - // for that subgraph. 
- std::unordered_map<string, std::pair<CompiledSubgraph*, int>> - entries_by_proto_key_ ABSL_GUARDED_BY(mu_); - - // Map from last_use to entry, used to mark entries for eviction in LRU - // order. If an entry's last_use counter is not present as a key in - // entries_by_last_use_ then the entry has been marked for eviction. - std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_); + ABSL_EXCLUSIVE_LOCKS_REQUIRED(TpuCompilationCacheInterface::mu_) override; }; } // namespace tpu diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc index f3e40df24dd..3b46f0f2d32 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc @@ -93,7 +93,9 @@ Status TpuCompilationCacheInterface::MarkEntryForEviction(int64 subgraph_uid) { "use TpuCompilationCacheInterface::Release."); } - VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction"; + VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key + << " for eviction. Debug string: " + << subgraph_to_evict->cache_entry_debug_string; entries_by_last_use_.erase(subgraph_to_evict->last_use); cache_size_ -= subgraph_to_evict->total_size; marked_for_eviction_size_ += subgraph_to_evict->total_size; @@ -231,7 +233,9 @@ void TpuCompilationCacheInterface::DiscardEntryRefs( CompiledSubgraph* TpuCompilationCacheInterface::MarkOldestEntryForEviction() { CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second; - VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction"; + VLOG(1) << "Marking " << entry_to_mark->subgraph_key + << " for eviction. Debug string: " + << entry_to_mark->cache_entry_debug_string; entries_by_last_use_.erase(entry_to_mark->last_use); cache_size_ -= entry_to_mark->total_size; marked_for_eviction_size_ += entry_to_mark->total_size; @@ -291,7 +295,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent( const SessionMetadata* session_metadata, CompilationRefHolder* per_step_ref_holder, int64* uid, std::vector<string>* proto_key, std::vector<bool>* may_modify_variables, - std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas, + absl::Span<const xla::HloProto* const>* hlo_metadatas, const std::function<Status(TpuProgramGroupInterface*)>& compile_function) { std::vector<CompiledSubgraph*> removed_entries; auto status = CompileIfKeyAbsentHelper( @@ -328,7 +332,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( CompilationRefHolder* per_step_ref_holder, int64* uid, std::vector<string>* proto_key, std::vector<bool>* may_modify_variables, std::vector<CompiledSubgraph*>* removed_entries, - std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas, + absl::Span<const xla::HloProto* const>* hlo_metadatas, const std::function<Status(TpuProgramGroupInterface*)>& compile_function) { CompiledSubgraph* entry = nullptr; @@ -388,7 +392,8 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( TRACELITERAL("TPU host compilation cache: compilation done."); LOG(INFO) << strings::StrCat( "TPU host compilation cache: compilation done for cache_key(", - cache_key, "), session_name(", session_name, ")"); + cache_key, "), session_name(", session_name, "), subgraph_key(", + subgraph_key.debug_string, ")"); // If session_name is present, log some additional stats related to HBM // here, so that they can be associated directly to the session. 
if (!session_name.empty()) { diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h new file mode 100644 index 00000000000..f92893b78f6 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h @@ -0,0 +1,355 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h" +#include "tensorflow/core/tpu/kernels/trace_util.h" + +namespace tensorflow { +namespace tpu { + +// Base class that holds references to compiled protos so that the protos are +// not garbage-collected before being used by execute ops. Use +// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete +// ref holder object. +class CompilationRefHolder : public ResourceBase { + public: + ~CompilationRefHolder() override = default; +}; + +// Base class for a reference to a cached tpu program. A unique_ptr to a +// CompilationCacheEntryRef is returned by all the cache Lookup methods below, +// and ensures the underlying proto is not garbage-collected until the client +// discards the ptr. +template <typename CacheEntryType> +class CompilationCacheEntryRef { + public: + virtual ~CompilationCacheEntryRef() = default; + + // Returns a CompilationCacheEntry that should not be used beyond the lifetime + // of the tpu::CompilationCacheEntryRef. + virtual CacheEntryType get() = 0; + + // Mutates this ref to point to the entry's subentry (for + // sharding/unsharding) or main entry (unchanged) as specified by + // fetch_target. The refcount is kept unchanged, since we only track the + // refcount of the main entry. The entry ref needs to point to the main + // entry before this call. + // + // If the requested subentry does not exist, the ref will point to a nullptr + // entry, and the original entry will be unref'ed. 
+  virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0;
+};
+
+class TpuCompilationCacheInterface : public ResourceBase {
+ public:
+  explicit TpuCompilationCacheInterface(int64 max_cache_size);
+  ~TpuCompilationCacheInterface() override;
+
+  // Ensures there is an entry for key present in the cache. By the time
+  // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
+  // for key, and that entry will remain valid at least until
+  // per_step_ref_holder is deleted. The first call to CompileIfKeyAbsent with a
+  // key that is not in the cache will evaluate compile_function to compute the
+  // value to use in the entry. Subsequent calls with the same key will block
+  // until compile_function completes. Other cache reads and inserts may proceed
+  // on other threads while compile_function is executing. If
+  // per_step_ref_holder is nullptr then the caller is responsible for calling
+  // Release(subgraph_uid) to manually discard its reference to the compiled
+  // program once it no longer needs to look up the compiled program.
+  //
+  // compile_function should compile the subgraph represented by key and fill in
+  // one TPUExecutableProto per model-parallel core into its passed argument. It
+  // should return OK if and only if compilation succeeds. The executable proto
+  // vector will be discarded on non-OK status.
+  Status CompileIfKeyAbsent(
+      const TpuCompilationCacheKey& subgraph_key,
+      const SessionMetadata* session_metadata,
+      CompilationRefHolder* per_step_ref_holder, int64* uid,
+      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+      absl::Span<const xla::HloProto* const>* hlo_metadatas,
+      const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
+
+  // Differences between MarkEntryForEviction and Release:
+  // There are two modes of managing cache entries:
+  // 1) LRU eviction + pinning; 2) manual.
+  // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
+  // Otherwise it is manual mode (mainly used by XRT).
+  // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
+  // entries when callers know that they do not need them anymore.
+  // Release should only be used in mode 2) to explicitly remove an entry.
+
+  // Marks the entry indexed by `subgraph_uid` for eviction. This should only be
+  // called if per_step_ref_holder was NOT nullptr in the corresponding call to
+  // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
+  // subgraph_uid).
+  Status MarkEntryForEviction(int64 subgraph_uid);
+
+  // Manually discards a reference to the compiled subgraph. This should only be
+  // called if per_step_ref_holder was nullptr in the corresponding call to
+  // CompileIfKeyAbsent(subgraph_key, ...).
+  Status Release(int64 subgraph_uid);
+
+  // Looks up an executable corresponding to the model-parallel core index of
+  // the subgraph represented by key. On success a pointer to an EntryRef
+  // holding the program is returned in entry.
+  template <typename CacheEntryRef, typename CacheEntryRefImpl>
+  Status Lookup(const string& proto_key, std::unique_ptr<CacheEntryRef>* entry);
+
+  // Looks up an executable corresponding to the model-parallel core index of
+  // the subgraph represented by uid. On success a pointer to an EntryRef
+  // holding the program is returned in entry.
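+  // CacheEntryRef is the ref type handed back to callers; CacheEntryRefImpl
+  // is the concrete implementation constructed from (cache, entry,
+  // proto_index), e.g. TpuCompilationCacheExternal::EntryRefImpl.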
+  template <typename CacheEntryRef, typename CacheEntryRefImpl>
+  Status Lookup(int64 uid, int proto_index,
+                std::unique_ptr<CacheEntryRef>* entry);
+
+  // Looks up the subgraph represented by uid, and returns the vector of keys,
+  // one per core, corresponding to that subgraph.
+  Status GetKeysFromUid(int64 uid, std::vector<string>* keys);
+
+  // Makes a reference holder for this cache that can be stored in the per-step
+  // resource manager and will ensure that compiled entries persist until the
+  // end of a step.
+  CompilationRefHolder* MakePerStepRefHolder();
+
+  // Convenience method called by ~RefHolder without mu_ held. Calls
+  // DiscardEntryRef on every element of entries.
+  void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
+
+  string DebugString() const override { return "TpuCompilationCacheBase"; }
+
+ protected:
+  std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
+    if (!key.has_guaranteed_const) {
+      return key.prefix;
+    }
+    return absl::StrCat(key.prefix, "|", key.session_handle, "|",
+                        key.guaranteed_const_fingerprint());
+  }
+
+  // Private implementation of the generic CompilationRefHolder that knows about
+  // CompiledSubgraph entries.
+  class RefHolder : public CompilationRefHolder {
+   public:
+    explicit RefHolder(TpuCompilationCacheInterface* parent);
+    ~RefHolder() override;
+
+    // Adds entry to the list of entries that will be released when the
+    // RefHolder is destroyed. Each entry is released via a call to
+    // parent_->DiscardEntryRefs.
+    void AddRef(CompiledSubgraph* entry);
+
+    string DebugString() const override;
+
+   private:
+    TpuCompilationCacheInterface* parent_;  // Not owned.
+    std::vector<CompiledSubgraph*> entries_;
+  };
+
+  // The bulk of the implementation of CompileIfKeyAbsent(), with the exception
+  // of unloading programs that correspond to possibly removed cache
+  // entries. The split helps to manage locking since we prefer to perform
+  // unloading without holding extra locks.
+  Status CompileIfKeyAbsentHelper(
+      const TpuCompilationCacheKey& subgraph_key,
+      const SessionMetadata* session_metadata,
+      CompilationRefHolder* per_step_ref_holder, int64* uid,
+      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
+      std::vector<CompiledSubgraph*>* removed_entries,
+      absl::Span<const xla::HloProto* const>* hlo_metadatas,
+      const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
+
+  // This is called by the cache when entry is marked for eviction; by
+  // a RefHolder (via DiscardEntryRefs) when a step completes; and by
+  // an EntryRefImpl when it is destroyed. Releases one reference to entry
+  // if more than one remains. If only one reference is left, the entry is
+  // removed from cache_ and is returned to the caller, which must eventually
+  // call UnloadAndDestroy(). We do not call UnloadAndDestroy within
+  // DiscardEntryRef to avoid holding the lock during program unloading.
+  ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
+      CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Marks the oldest unmarked entry for eviction. Requires that there is at
+  // least one such entry. In case the evicted entry had only one reference,
+  // it is removed from the cache and returned to the caller, which must
+  // eventually call UnloadAndDestroy.
+  ABSL_MUST_USE_RESULT CompiledSubgraph* MarkOldestEntryForEviction()
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Updates data structures to indicate that entry, which had been marked for
+  // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
+  // entry is newly created, or an entry that has been marked for eviction but
+  // not yet evicted is looked up.
+  //
+  // First the entry is unmarked for eviction, i.e. the cache gains a reference
+  // to entry, entry's last_use field is set to be the most recent value of
+  // use_counter_, and entries_by_last_use_ is updated accordingly.
+  //
+  // Next, the size of the cache is examined to see if any other entries need to
+  // be marked for eviction now that entry has been unmarked. While the total
+  // size of unmarked cached entries is greater than max_cache_size_, entries
+  // are marked for eviction in LRU order. The most recently used entry is never
+  // marked for eviction, so an entry larger than the max cache size will remain
+  // in the cache until it is replaced by something else. In case some entries
+  // actually were removed from the cache, they are returned to the caller via
+  // removed_entries. The caller must eventually delete them by calling
+  // UnloadAndDestroy.
+  void LookupEntryMarkedForEviction(
+      CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Removes the entry with the given key from the cache.
+  size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Inserts the given key and entry into the cache.
+  void InsertEntry(const string& key, CompiledSubgraph* entry)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Returns the cache key matching the given subgraph_key.
+  string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Creates a new entry by running initialize_programs and places it in the
+  // cache to be looked up by key. The new entry is in the 'marked for eviction'
+  // state (not present in entries_by_last_use_) and the caller is expected to
+  // call LookupEntryMarkedForEviction after InitializeEntry.
+  //
+  // **InitializeEntry releases mu_ during the call to initialize_programs.**
+  virtual CompiledSubgraph* InitializeEntry(
+      const string& key,
+      const std::function<Status(TpuProgramGroupInterface*)>&
+          initialize_programs,
+      const TpuCompilationCacheKey& subgraph_key)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
+
+  // Unloads the program associated with the entry from all local devices
+  // and deletes the entry itself. It is assumed no one else has a reference
+  // to it and all related keys have already been removed from the cache.
+  // The call can perform device IO, so no locks should be held while calling
+  // it.
+  void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
+
+  // The maximum total size of entries that are stored in the cache before
+  // entries are marked for eviction.
+  const int64 max_cache_size_;
+  // Mutex to protect access to shared resources in a multi-threaded
+  // environment.
+  absl::Mutex mu_;
+  // The total size of entries that are stored and not marked for eviction.
+  int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
+  // The total size of entries that are marked for eviction.
+  int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
+  // The value to assign to the last_use field of the next entry that is looked
+  // up.
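+  // use_counter_ is incremented on every lookup, so larger last_use values
+  // correspond to more recently used entries in entries_by_last_use_.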
+  int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
+  // session_key_map_ and fingerprint_key_map_ are used for looking up the
+  // cache_ key matching a given subgraph key. When doing a lookup, check
+  // session_key_map_ first to avoid unnecessary fingerprint computation.
+  // Map from key prefix + session_handle to a cache_ key.
+  absl::node_hash_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
+  // Map from key prefix + fingerprint to a cache_ key.
+  absl::node_hash_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
+  // All the subgraph entries that can be looked up in the cache. An entry is
+  // marked for eviction iff it is present in cache_ and not in
+  // entries_by_last_use_.
+  std::unordered_map<string, CompiledSubgraph*> cache_ ABSL_GUARDED_BY(mu_);
+  // All the subgraph entries that can be looked up in the cache, indexed by
+  // uid.
+  absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
+      ABSL_GUARDED_BY(mu_);
+  // All the protos that can be looked up in the cache, indexed by proto
+  // key. The value of the map is a subgraph and the index of the proto compiled
+  // for that subgraph.
+  std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
+      entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
+  // Map from last_use to entry, used to mark entries for eviction in LRU
+  // order. If an entry's last_use counter is not present as a key in
+  // entries_by_last_use_ then the entry has been marked for eviction.
+  std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
+
+  TpuCompilationCacheMetrics tpu_compilation_cache_metrics_;
+
+ private:
+  TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
+  TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
+      delete;
+};
+
+template <typename CacheEntryRef, typename CacheEntryRefImpl>
+Status TpuCompilationCacheInterface::Lookup(
+    int64 uid, int proto_index, std::unique_ptr<CacheEntryRef>* entry) {
+  entry->reset();
+
+  profiler::TraceMe proto_lookup_traceme(
+      "TPU compilation cache proto lookup by uid",
+      /*level=*/2);
+
+  absl::MutexLock lock(&mu_);
+  const auto iter = entries_by_uid_.find(uid);
+  if (iter == entries_by_uid_.end()) {
+    return errors::NotFound("No subgraph found for uid ", uid);
+  }
+  CompiledSubgraph* cache_entry = iter->second;
+  if (proto_index < 0 ||
+      proto_index >= cache_entry->tpu_program_group->program_count()) {
+    return errors::NotFound("No proto found for core index ", proto_index,
+                            " in subgraph with uid ", uid);
+  }
+  *entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
+  return Status::OK();
+}
+
+template <typename CacheEntryRef, typename CacheEntryRefImpl>
+Status TpuCompilationCacheInterface::Lookup(
+    const string& proto_key, std::unique_ptr<CacheEntryRef>* entry) {
+  entry->reset();
+
+  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
+                                         /*level=*/2);
+
+  absl::MutexLock lock(&mu_);
+  const auto iter = entries_by_proto_key_.find(proto_key);
+  if (iter == entries_by_proto_key_.end()) {
+    return errors::NotFound("No proto found for key ", proto_key);
+  }
+  CompiledSubgraph* cache_entry = iter->second.first;
+  int proto_index = iter->second.second;
+  *entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
+  return Status::OK();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc index 8b2e832a69e..9285dff62ce 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc @@ -42,7 +42,7 @@ std::string GetName(CompilationCacheFetchTarget target) { } // namespace TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup( - TpuCompilationCacheExternal* cache) + TpuCompilationCacheInterface* cache) : cache_(cache) {} TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() { @@ -50,17 +50,19 @@ TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() { } Status TpuCompilationCacheLocalLookup::Lookup( - const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry, + const string& proto_key, + std::unique_ptr<TpuCompilationCacheEntryRef>* entry, CompilationCacheFetchTarget fetch_target) { profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup", /*level=*/2); - Status s = cache_->Lookup(proto_key, entry); + Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>( + proto_key, entry); VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status " << s; if (!s.ok()) { return s; } - s = cache_->ToSubEntryRef(entry->get(), fetch_target); + s = (*entry)->ToSubEntryRef(fetch_target); VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status " << s; @@ -69,17 +71,18 @@ Status TpuCompilationCacheLocalLookup::Lookup( Status TpuCompilationCacheLocalLookup::Lookup( int64 uid, int proto_index, - std::unique_ptr<CompilationCacheEntryRef>* entry, + std::unique_ptr<TpuCompilationCacheEntryRef>* entry, CompilationCacheFetchTarget fetch_target) { profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid", /*level=*/2); - Status s = cache_->Lookup(uid, proto_index, entry); + Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>( + uid, proto_index, entry); VLOG(1) << "Looked up uid " << uid << ", index " << proto_index << " in local subgraph cache status " << s; if (!s.ok()) { return s; } - s = cache_->ToSubEntryRef(entry->get(), fetch_target); + s = (*entry)->ToSubEntryRef(fetch_target); VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status " << s; return s; diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h index 0d068e1bdd1..21ca74c46a8 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h @@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ -#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" namespace tensorflow { namespace tpu { @@ -28,6 +30,11 @@ namespace tpu { // and when they need to communicate over RPC. class TpuCompilationCacheLookup : public ResourceBase { public: + using TpuCompilationCacheEntryRef = + ::tensorflow::tpu::CompilationCacheEntryRef<TpuCompilationCacheEntry>; + using EntryRefImpl = + ::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl; + ~TpuCompilationCacheLookup() override = default; // Looks up an executable corresponding to the model-parallel core index of @@ -42,11 +49,11 @@ class TpuCompilationCacheLookup : public ResourceBase { // fetch_target requests one of them, then after this call // (*entry)->get().get_executable() will return nullptr. virtual Status Lookup(const string& proto_key, - std::unique_ptr<CompilationCacheEntryRef>* entry, + std::unique_ptr<TpuCompilationCacheEntryRef>* entry, CompilationCacheFetchTarget fetch_target) = 0; virtual Status Lookup(const string& proto_key, - std::unique_ptr<CompilationCacheEntryRef>* entry) { + std::unique_ptr<TpuCompilationCacheEntryRef>* entry) { return Lookup(proto_key, std::move(entry), CompilationCacheFetchTarget::MAIN); } @@ -56,33 +63,30 @@ class TpuCompilationCacheLookup : public ResourceBase { // returned in program. The wrapper is guaranteed to be valid only during the // execution of the Op requesting the proto. virtual Status Lookup(int64 uid, int proto_index, - std::unique_ptr<CompilationCacheEntryRef>* entry, + std::unique_ptr<TpuCompilationCacheEntryRef>* entry, CompilationCacheFetchTarget fetch_target) = 0; virtual Status Lookup(int64 uid, int proto_index, - std::unique_ptr<CompilationCacheEntryRef>* entry) { + std::unique_ptr<TpuCompilationCacheEntryRef>* entry) { return Lookup(uid, proto_index, std::move(entry), CompilationCacheFetchTarget::MAIN); } }; -// Forward declaration to break cycle dependency graph. -class TpuCompilationCacheExternal; - // Class for looking up ISA protos when the execute and compile Op are in the // same address space. The proto is simply looked up in the compilation cache, // without any serialization taking place. 
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup { public: - explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheExternal* cache); + explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache); ~TpuCompilationCacheLocalLookup() override; Status Lookup(const string& proto_key, - std::unique_ptr<CompilationCacheEntryRef>* entry, + std::unique_ptr<TpuCompilationCacheEntryRef>* entry, CompilationCacheFetchTarget fetch_target) override; Status Lookup(int64 uid, int proto_index, - std::unique_ptr<CompilationCacheEntryRef>* entry, + std::unique_ptr<TpuCompilationCacheEntryRef>* entry, CompilationCacheFetchTarget fetch_target) override; string DebugString() const override; @@ -90,10 +94,10 @@ class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup { private: // The subgraph compilation cache, in the same process address space where the // lookups are happening. - TpuCompilationCacheExternal* cache_; + TpuCompilationCacheInterface* cache_; }; } // namespace tpu } // namespace tensorflow -#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index c8faba1d975..7ab1c9b8027 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/tpu_configuration.h" #include "tensorflow/core/tpu/tpu_defs.h" diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h index 0f21e458828..36f9fa96db1 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -24,7 +24,6 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc index 43452b912ec..ecda2ef062e 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.cc +++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc @@ -209,15 +209,8 @@ xla::HloProto TpuProgramGroup::hlo_metadata(int core_index) const { serialized_hlo_proto); } -std::vector<std::shared_ptr<const xla::HloProto>> -TpuProgramGroup::hlo_metadatas() const { - const size_t metadata_count = program_count(); - std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadatas; - hlo_metadatas.resize(metadata_count); - for (size_t i = 0; i < metadata_count; ++i) { - hlo_metadatas[i] = std::make_shared<const xla::HloProto>(hlo_metadata(i)); - } - return hlo_metadatas; +absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const { + return absl::MakeConstSpan(hlo_metadatas_); } } // namespace tpu diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h index de8256a9e59..0ade58e6daa 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.h +++ b/tensorflow/core/tpu/kernels/tpu_program_group.h @@ -139,11 +139,15 @@ class TpuProgramGroup : public TpuProgramGroupInterface { const xla::HloProto& hlo_metadata() const { return hlo_metadata_; } void set_hlo_metadata(const xla::HloProto& hlo_metadata) { hlo_metadata_ = hlo_metadata; + + // TODO(henrytan): initialize hlo_metadatas_ for multi program support. + if (hlo_metadatas_.empty()) { + hlo_metadatas_.push_back(&hlo_metadata_); + } } xla::HloProto hlo_metadata(int core_index) const; - std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadatas() - const override; + absl::Span<const xla::HloProto* const> hlo_metadatas() const override; private: std::vector<bool> may_modify_variables_; @@ -153,6 +157,7 @@ class TpuProgramGroup : public TpuProgramGroupInterface { TPUExecutableInfoProto executable_info_; TPUHostTransferInfoProto host_transfer_info_; xla::HloProto hlo_metadata_; + std::vector<const xla::HloProto*> hlo_metadatas_; }; } // namespace tpu diff --git a/tensorflow/core/tpu/kernels/tpu_program_group_interface.h b/tensorflow/core/tpu/kernels/tpu_program_group_interface.h index a4f74fb750d..8d8dd5a8786 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group_interface.h +++ b/tensorflow/core/tpu/kernels/tpu_program_group_interface.h @@ -44,9 +44,9 @@ class TpuProgramGroupInterface { // Logs program memory summary. virtual bool LogProgramMemorySummary() = 0; - // Hlo metadatas. - virtual std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadatas() - const = 0; + // Hlo metadatas. The pointers can only be used as long as the cache entry is + // referenced. + virtual absl::Span<const xla::HloProto* const> hlo_metadatas() const = 0; // Boolean array to indicate if the modification of variables are // allowed.