TPU library internal change.

PiperOrigin-RevId: 316171774
Change-Id: I85b3c2639c734e391692f568bffae0efe116e9af
This commit is contained in:
Henry Tan 2020-06-12 13:57:03 -07:00 committed by TensorFlower Gardener
parent 9130ba73e6
commit 1e3eddd152
11 changed files with 848 additions and 587 deletions

View File

@ -19,9 +19,9 @@ cc_library(
deps = [
":tpu_compile_op_support",
":tpu_mesh_state_interface",
":tpu_program_group_interface",
":tpu_util",
":tpu_util_hdrs",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:shape_inference",
"//tensorflow/compiler/tf2xla:tf2xla_util",
@ -30,16 +30,16 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
# "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
)
@ -157,28 +157,14 @@ cc_library(
"tpu_compilation_cache_entry.h",
],
deps = [
":compiled_subgraph",
":tpu_compilation_cache_proto_cc",
":tpu_executable_info_proto_cc",
":tpu_program_group",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core/lib/core:refcount",
"//tensorflow/core/platform:casts",
],
)
cc_library(
name = "tpu_compilation_cache_entry_impl",
srcs = [],
hdrs = ["tpu_compilation_cache_entry_impl.h"],
deps = [
":compiled_subgraph",
":tpu_compilation_cache_interface",
":tpu_executable_info_proto_cc",
],
)
cc_library(
name = "tpu_compilation_cache_lookup",
srcs = ["tpu_compilation_cache_lookup.cc"],
@ -188,7 +174,6 @@ cc_library(
deps = [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_external",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_proto_cc",
"//tensorflow/core/lib/core:refcount",
"//tensorflow/core/platform:status",
@ -262,35 +247,6 @@ cc_library(
],
)
cc_library(
name = "tpu_compilation_cache_interface",
srcs = ["tpu_compilation_cache_interface.cc"],
hdrs = ["tpu_compilation_cache_interface.h"],
deps = [
":compiled_subgraph",
":tpu_compilation_cache_key",
":tpu_compilation_cache_metrics_hdrs",
":tpu_compilation_cache_proto_cc",
":tpu_util",
":tpu_util_hdrs",
":trace_util_hdrs",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/platform:casts", # buildcleaner: keep
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
],
alwayslink = 1,
)
cc_library(
name = "tpu_compilation_cache_external",
srcs = ["tpu_compilation_cache_external.cc"],
@ -300,8 +256,6 @@ cc_library(
deps = [
":compiled_subgraph",
":tpu_compilation_cache_entry",
":tpu_compilation_cache_entry_impl",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_key",
":tpu_compilation_cache_metrics", # buildcleaner: keep
":tpu_compilation_cache_metrics_hdrs",
@ -401,7 +355,6 @@ cc_library(
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,

View File

@ -25,9 +25,6 @@ limitations under the License.
namespace tensorflow {
namespace tpu {
// Forward declaration to avoid circular dependency.
class TpuCompilationCacheInterface;
// Cache for compiled TPU program.
//
// Each key identifies a unique subgraph, and the value is the vector of
@ -103,7 +100,10 @@ class TpuCompilationCacheInterface;
// unmarked and set to most recently used.
//
struct CompiledSubgraph : public core::RefCounted {
TpuCompilationCacheInterface* parent = nullptr; // Not owned.
// TODO(henrytan): once `TpuCompilationCache` and
// `TpuCompilationCacheExternal` inherits from `TpuCompilationCacheInterface`
// update void* with `TpuCompilationCacheInterface`
void* parent = nullptr; // Not owned.
bool initialized = false;
@ -145,7 +145,7 @@ struct CompiledSubgraph : public core::RefCounted {
// owning main entry.
CompiledSubgraph* main_entry = nullptr;
// Compiled TPU program group.
// Compiled Tpu program.
std::unique_ptr<TpuProgramGroupInterface> tpu_program_group;
// Computes total program size.

View File

@ -23,7 +23,7 @@ limitations under the License.
namespace tensorflow {
namespace tpu {
// A version of `CompilationCacheEntry` to access Tpu binary program
// A version of `CompilationCacheEntry` that exposes Tpu binary program
// `XLA_TpuProgram`.
class TpuCompilationCacheEntry {
public:
@ -42,6 +42,28 @@ class TpuCompilationCacheEntry {
int core_index_;
};
// Base class for a reference to a cached proto. A unique_ptr to a
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
// and ensures the underlying proto is not garbage-collected until the client
// discards the ptr.
class CompilationCacheEntryRef {
public:
virtual ~CompilationCacheEntryRef() = default;
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
// of the CompilationCacheEntryRef.
virtual TpuCompilationCacheEntry get() = 0;
};
// Base class that holds references to compiled protos so that the protos are
// not garbage-collected before being used by execute ops. Use
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
// ref holder object.
class CompilationRefHolder : public ResourceBase {
public:
~CompilationRefHolder() override = default;
};
} // namespace tpu
} // namespace tensorflow

View File

@ -1,108 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
namespace tensorflow {
namespace tpu {
// Wrapper for a cache entry that holds a reference to the entry until the
// wrapper is deleted. This wrapper is the concrete type of
// CompilationCacheEntryRef returned by Lookup.
template <typename CacheEntryType>
class CompilationCacheEntryRefImpl
: public CompilationCacheEntryRef<CacheEntryType> {
public:
CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent,
CompiledSubgraph* entry, int index);
~CompilationCacheEntryRefImpl() override;
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override;
protected:
TpuCompilationCacheInterface* parent_; // Not owned.
// A reference to entry_ is acquired in the constructor and released via
// parent->DiscardEntryRefs in the destructor.
CompiledSubgraph* entry_;
// The index of the program in entry_ that is returned by the get method.
int index_;
};
template <typename CacheEntryType>
CompilationCacheEntryRefImpl<CacheEntryType>::CompilationCacheEntryRefImpl(
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
: parent_(parent), entry_(entry), index_(index) {
if (entry_ == nullptr) {
return;
}
if (entry_->main_entry == nullptr) {
entry_->Ref();
} else {
// This is a sharding/unsharding entry nested in a main entry. Only
// refcount the main entry.
entry_->main_entry->Ref();
}
}
template <typename CacheEntryType>
CompilationCacheEntryRefImpl<CacheEntryType>::~CompilationCacheEntryRefImpl() {
if (entry_ == nullptr) {
return;
}
if (entry_->main_entry == nullptr) {
parent_->DiscardEntryRefs({entry_});
} else {
parent_->DiscardEntryRefs({entry_->main_entry});
}
}
template <typename CacheEntryType>
Status CompilationCacheEntryRefImpl<CacheEntryType>::ToSubEntryRef(
CompilationCacheFetchTarget fetch_target) {
CompiledSubgraph* target = nullptr;
switch (fetch_target) {
case CompilationCacheFetchTarget::MAIN:
target = entry_;
break;
case CompilationCacheFetchTarget::SHARDING:
target = entry_->sharding_entry.get();
break;
case CompilationCacheFetchTarget::UNSHARDING:
target = entry_->unsharding_entry.get();
break;
default:
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
}
if (target == nullptr) {
// Cache entry does not have an unsharding subentry. Unref and replace
// with nullptr.
parent_->DiscardEntryRefs({entry_});
}
// Otherwise, since the refcount is always on the main entry, we don't
// need ref/unref.
entry_ = target;
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_

View File

@ -50,6 +50,14 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry,
entry->initialized = true;
}
std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
if (!key.has_guaranteed_const) {
return key.prefix;
}
return absl::StrCat(key.prefix, "|", key.session_handle, "|",
key.guaranteed_const_fingerprint());
}
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
// data to compute the fingerprint.
std::string GuaranteedConstFingerprint(
@ -115,32 +123,85 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
} // namespace
TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl(
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
: CompilationCacheEntryRefImpl<TpuCompilationCacheEntry>(parent, entry,
index) {}
TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() {
if (entry_ == nullptr) {
// Create an empty entry if the entry is nullptr. This corresponds to
// non-existing sharding/unsharding entries.
return TpuCompilationCacheEntry();
TpuCompilationCacheExternal::TpuCompilationCacheExternal(int64_t max_cache_size)
: max_cache_size_(max_cache_size) {
if (max_cache_size < 0) {
LOG(FATAL) << "`max_cache_size` value must be greater than equal to 0";
}
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
}
// Tears down the cache: force-drops any client references that were never
// returned, evicts every remaining entry oldest-first, then verifies that all
// bookkeeping tables and size counters are empty.
TpuCompilationCacheExternal::~TpuCompilationCacheExternal() {
  VLOG(1) << "TpuCompilationCacheExternal::~TpuCompilationCacheExternal()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    while (entry.second->external_references > 0) {
      TF_CHECK_OK(Release(entry.first));
    }
  }
  // Drain the LRU table, unloading each evicted entry's programs.
  while (!entries_by_last_use_.empty()) {
    UnloadAndDestroy(MarkOldestEntryForEviction());
  }
  // By the time the cache is deleted all reference holders should have already
  // been deleted, since they were holding references to the cache. So all
  // entries should be gone at this point.
  CHECK_EQ(cache_store_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(entries_by_proto_key_.size(), 0);
  CHECK_EQ(cache_size_, 0);
  CHECK_EQ(marked_for_eviction_size_, 0);
}
// Resolves `subgraph_key` to the canonical cache-store key. Keys without
// guaranteed-const inputs are used verbatim; otherwise the session-handle
// alias map is consulted first, then the fingerprint alias map. Returns ""
// when no mapping exists.
std::string TpuCompilationCacheExternal::FindCacheKey(
    const TpuCompilationCacheKey& subgraph_key) const {
  if (!subgraph_key.has_guaranteed_const) {
    return subgraph_key.prefix;
  }
  auto iter = session_key_map_.find(
      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
  if (iter != session_key_map_.end()) {
    return iter->second;
  }
  iter = fingerprint_key_map_.find(strings::StrCat(
      subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
  // Bug fix: this iterator comes from fingerprint_key_map_, so it must be
  // compared against fingerprint_key_map_.end(). The original compared it
  // against session_key_map_.end(), which is undefined behavior (iterators
  // from different containers) and can yield a spurious hit or miss.
  if (iter != fingerprint_key_map_.end()) {
    return iter->second;
  }
  VLOG(1) << "No matching cache key found for key "
          << ConstructCompilationCacheKey(subgraph_key);
  return "";
}
// Initializes `entry`'s cache bookkeeping fields (parent, key, uid, debug
// string) and, when `subgraph_key` carries guaranteed-const inputs, records
// the session-handle and fingerprint aliases that FindCacheKey consults.
void TpuCompilationCacheExternal::InsertEntry(
    const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key,
    CompiledSubgraph* entry) {
  entry->parent = this;
  entry->subgraph_key = cache_key;
  entry->uid = get_uid();
  TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
  entry->cache_entry_debug_string = subgraph_key.prefix;
  VLOG(1) << "Cache Initializing Entry Session Debug "
          << entry->cache_entry_debug_string;

  // Keys without guaranteed-const inputs need no alias-map entries.
  if (!subgraph_key.has_guaranteed_const) {
    return;
  }
  session_key_map_.insert(std::make_pair(
      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle),
      cache_key));
  fingerprint_key_map_.insert(std::make_pair(
      strings::StrCat(subgraph_key.prefix,
                      subgraph_key.guaranteed_const_fingerprint()),
      cache_key));
}
CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
const string& key,
const std::function<Status(TpuProgramGroupInterface*)>& initialize_program,
const std::function<Status(TpuProgramGroup*)>& initialize_program,
const TpuCompilationCacheKey& subgraph_key) {
CompiledSubgraph* main_entry = new CompiledSubgraph();
main_entry->parent = this;
main_entry->subgraph_key = key;
main_entry->uid = get_uid();
// TODO(henrytan): implement TpuCompilationCacheKey.debug_string.
main_entry->cache_entry_debug_string = subgraph_key.prefix;
VLOG(1) << "Cache Initializing Entry Session Debug "
<< main_entry->cache_entry_debug_string;
// Add the entry to the cache, with size zero since there are no compiled
// programs in it. Once the subgraph has been compiled,
@ -151,7 +212,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
// who created the entry. A second reference, owned by the cache, will be
// added below since we leave the entry in the 'marked for eviction' state
// here.
InsertEntry(key, main_entry);
InsertEntry(key, subgraph_key, main_entry);
// Initialize the programs outside the lock so that other cache operations
// can proceed during the (potentially lengthy) initialization.
@ -259,5 +320,470 @@ TpuCompilationCacheExternal::CreateCompilationCacheKey(
}
return key;
}
// Creates a caller-owned per-step reference holder bound to this cache.
// References handed to it (see CompileIfKeyAbsentHelper) are presumably
// discarded when the holder is destroyed at step end — confirm with RefHolder.
TpuCompilationRefHolder* TpuCompilationCacheExternal::MakePerStepRefHolder() {
  return new RefHolder(this);
}
// Marks the entry identified by `subgraph_uid` for eviction from the host
// cache. Only valid for entries with zero external (client-held) references;
// clients holding references must use Release() instead. Returns OK if the
// entry was marked or had already been evicted.
Status TpuCompilationCacheExternal::MarkEntryForEviction(int64 subgraph_uid) {
  profiler::TraceMe key_release_traceme(
      "TPU compilation cache possibly evict uid",
      /*level=*/2);
  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);
    if (iter == entries_by_uid_.end()) {
      // If already evicted, return ok.
      return Status::OK();
    }

    // Mark entry for eviction.
    CompiledSubgraph* subgraph_to_evict = iter->second;
    // If there are external references, should not use this API.
    if (subgraph_to_evict->external_references != 0) {
      return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
                              " external_references greater than zero. Should "
                              "use TpuCompilationCache::Release.");
    }

    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction";
    // Remove from the LRU table and move its size from the active-cache
    // counter to the marked-for-eviction counter.
    entries_by_last_use_.erase(subgraph_to_evict->last_use);
    cache_size_ -= subgraph_to_evict->total_size;
    marked_for_eviction_size_ += subgraph_to_evict->total_size;

    // Evict if refcount exactly one, otherwise only discard cache's reference
    // to the entry while the actual eviction will happen when refholder's
    // references go away.
    deleted_entry = DiscardEntryRef(subgraph_to_evict);

    VLOG(1) << "After possibly evicting entry " << subgraph_uid
            << " refs cache is " << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }

  // Unload from device cache if entry is evicted from host cache. Done
  // outside the lock.
  UnloadAndDestroy(deleted_entry);
  return Status::OK();
}
// Releases one external (client-held) reference on the entry identified by
// `subgraph_uid`. If that was the last reference overall, the entry is
// unlinked from all cache tables and its programs are unloaded outside the
// lock. Returns NotFound when no entry with that uid exists.
Status TpuCompilationCacheExternal::Release(int64 subgraph_uid) {
  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
                                        /*level=*/2);

  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);

    if (iter == entries_by_uid_.end()) {
      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
    }

    // A release without an outstanding external reference is a caller bug.
    CHECK_GT(iter->second->external_references, 0);
    --iter->second->external_references;

    deleted_entry = DiscardEntryRef(iter->second);

    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
            << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  // Actual unload/destroy happens outside the lock.
  UnloadAndDestroy(deleted_entry);
  return Status::OK();
}
// Unloads the device programs held by `entry` and drops its final reference.
// A null `entry` is a no-op; a non-null entry must be down to its last
// reference (callers pass the result of DiscardEntryRef, which is either
// nullptr or a last-reference entry).
void TpuCompilationCacheExternal::UnloadAndDestroy(CompiledSubgraph* entry) {
  if (entry == nullptr) {
    return;
  }
  CHECK(entry->RefCountIsOne());
  entry->tpu_program_group->UnloadAndDestroyPrograms();
  entry->Unref();
}
// Erases `key` from cache_store_ (updating the entry-count metric) and, for
// keys with guaranteed-const inputs, erases the matching session-handle and
// fingerprint aliases. Returns the number of elements erased from
// cache_store_ (0 or 1). NOTE(review): mutates the cache tables, so the
// caller presumably holds mu_ — confirm against call sites.
size_t TpuCompilationCacheExternal::RemoveEntry(const string& key) {
  auto erased = cache_store_.erase(key);
  TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
  // The alias maps are keyed on components of the parsed key, so the string
  // key must be parsed back; a malformed key here is a fatal invariant breach.
  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      parsed_key_or_status.ConsumeValueOrDie();
  if (!parsed_key.has_guaranteed_const) {
    return erased;
  }
  session_key_map_.erase(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
  fingerprint_key_map_.erase(strings::StrCat(
      parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
  return erased;
}
// Drops one reference to `entry`. If that was the last reference, the entry
// is unlinked from every cache table and returned so the caller can unload
// and destroy it outside the lock; otherwise returns nullptr.
ABSL_MUST_USE_RESULT CompiledSubgraph*
TpuCompilationCacheExternal::DiscardEntryRef(CompiledSubgraph* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it from
    // the cache in such a way that it can't be restored by being looked up
    // again.

    // Sanity-check that it has been marked for eviction.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    marked_for_eviction_size_ -= entry->total_size;

    // Remove the entry from the cache.
    auto erased = RemoveEntry(entry->subgraph_key);

    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
    // Also drop the per-core proto-key index entries.
    for (const string& key : entry->proto_key) {
      erased = entries_by_proto_key_.erase(key);
      CHECK_EQ(erased, 1);
    }
    // The actual deletion will happen outside the lock in UnloadAndDestroy().
    return entry;
  }
  entry->Unref();
  return nullptr;
}
// Drops one reference on each subgraph in `entries` under the cache lock,
// then unloads (outside the lock) every entry whose last reference went away.
void TpuCompilationCacheExternal::DiscardEntryRefs(
    gtl::ArraySlice<CompiledSubgraph*> entries) {
  std::vector<CompiledSubgraph*> to_destroy;
  {
    absl::MutexLock lock(&mu_);
    to_destroy.reserve(entries.size());
    for (CompiledSubgraph* subgraph : entries) {
      // DiscardEntryRef returns the entry when it must be destroyed, or
      // nullptr otherwise; UnloadAndDestroy below tolerates nullptr.
      to_destroy.push_back(DiscardEntryRef(subgraph));
    }

    VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
            << " entries (" << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  for (CompiledSubgraph* dead : to_destroy) {
    UnloadAndDestroy(dead);
  }
}
// Marks the least-recently-used entry for eviction: removes it from the LRU
// table, moves its size from cache_size_ to marked_for_eviction_size_, and
// drops the cache's reference. Returns the entry if it must now be destroyed,
// nullptr otherwise.
ABSL_MUST_USE_RESULT CompiledSubgraph*
TpuCompilationCacheExternal::MarkOldestEntryForEviction() {
  // entries_by_last_use_ is keyed by use counter, so begin() is the oldest.
  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
  entries_by_last_use_.erase(entry_to_mark->last_use);
  cache_size_ -= entry_to_mark->total_size;
  marked_for_eviction_size_ += entry_to_mark->total_size;
  // Discard the cache's reference to entry. If steps are holding onto
  // references to entry it won't be deleted until the last step holding it
  // completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  return DiscardEntryRef(entry_to_mark);
}
// Re-activates `entry`, which was previously marked for eviction (or is newly
// created): takes a cache-owned reference, moves its size back into
// cache_size_, and then marks LRU entries for eviction until the cache fits
// within max_cache_size_. Entries that must be destroyed are appended to
// `removed_entries` for the caller to unload outside the lock.
void TpuCompilationCacheExternal::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) {
  // The entry was previously marked for eviction (or is newly created) so
  // unmark it. Add a reference (owned by the cache), update the cache size, and
  // mark something old for eviction if necessary.
  entry->Ref();
  marked_for_eviction_size_ -= entry->total_size;
  cache_size_ += entry->total_size;

  // Mark the least-recently-used non-marked entry for eviction. Never mark the
  // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
  // which means there's only one entry not already marked for eviction), so
  // that an entry persists in the cache even if it is larger than the allocated
  // cache size.
  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
    if (auto entry_to_evict = MarkOldestEntryForEviction()) {
      removed_entries->push_back(entry_to_evict);
    }
  }
}
// Narrows `entry` in place to the sub-entry (main/sharding/unsharding)
// selected by `fetch_target`. `entry` must be a TpuEntryRefImpl produced by
// this cache's Lookup methods; the unchecked static_cast relies on that.
Status TpuCompilationCacheExternal::ToSubEntryRef(
    CompilationCacheEntryRef* entry,
    CompilationCacheFetchTarget fetch_target) const {
  return static_cast<TpuEntryRefImpl*>(entry)->ToSubEntryRef(fetch_target);
}
// Takes a reference on `entry` — or, for a sharding/unsharding entry nested
// in a main entry, on the owning main entry — so the cached programs cannot
// be evicted while this ref object is alive. A null `entry` yields an empty
// ref (see get()).
TpuCompilationCacheExternal::TpuEntryRefImpl::TpuEntryRefImpl(
    TpuCompilationCacheExternal* parent, CompiledSubgraph* entry, int index)
    : parent_(parent), entry_(entry), index_(index) {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    entry_->Ref();
  } else {
    // This is a sharding/unsharding entry nested in a main entry. Only refcount
    // the main entry.
    entry_->main_entry->Ref();
  }
}
// Returns the reference taken in the constructor to the parent cache, which
// may trigger eviction and unload of the entry if it was the last reference.
TpuCompilationCacheExternal::TpuEntryRefImpl::~TpuEntryRefImpl() {
  if (entry_ == nullptr) {
    return;
  }
  // Mirror the constructor: the refcount lives on the main entry when this
  // ref points at a nested sharding/unsharding entry.
  if (entry_->main_entry == nullptr) {
    parent_->DiscardEntryRefs({entry_});
  } else {
    parent_->DiscardEntryRefs({entry_->main_entry});
  }
}
// Materializes a view of the referenced program. A null entry_ (a missing
// sharding/unsharding sub-entry) yields a default-constructed, empty
// TpuCompilationCacheEntry.
TpuCompilationCacheEntry TpuCompilationCacheExternal::TpuEntryRefImpl::get() {
  if (entry_ != nullptr) {
    return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
  }
  // Create an empty entry for the nullptr case, which corresponds to
  // non-existing sharding/unsharding entries.
  return TpuCompilationCacheEntry();
}
// Repoints this ref at the main, sharding, or unsharding sub-entry selected
// by `fetch_target`. If the requested sub-entry does not exist, the held
// reference is dropped and entry_ becomes nullptr (get() then returns an
// empty entry). Returns InvalidArgument for an unknown fetch target.
Status TpuCompilationCacheExternal::TpuEntryRefImpl::ToSubEntryRef(
    CompilationCacheFetchTarget fetch_target) {
  CompiledSubgraph* target = nullptr;
  switch (fetch_target) {
    case CompilationCacheFetchTarget::MAIN:
      target = entry_;
      break;
    case CompilationCacheFetchTarget::SHARDING:
      target = entry_->sharding_entry.get();
      break;
    case CompilationCacheFetchTarget::UNSHARDING:
      target = entry_->unsharding_entry.get();
      break;
    default:
      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
  }

  if (target == nullptr) {
    // Cache entry does not have an unsharding subentry. Unref and replace
    // with nullptr.
    parent_->DiscardEntryRefs({entry_});
  }
  // Otherwise, since the refcount is always on the main entry, we don't need
  // ref/unref.
  entry_ = target;
  return Status::OK();
}
// Looks up a compiled program by (uid, proto_index) and returns a
// reference-holding wrapper; the entry cannot be evicted while the returned
// ref is alive. Returns NotFound for an unknown uid or an out-of-range
// proto_index.
Status TpuCompilationCacheExternal::Lookup(
    int64 uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme(
      "TPU compilation cache proto lookup by uid",
      /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  CompiledSubgraph* cache_entry = iter->second;
  if (proto_index < 0 ||
      proto_index >= cache_entry->tpu_program_group->program_size()) {
    return errors::NotFound("No proto found for core index ", proto_index,
                            " in subgraph with uid ", uid);
  }
  // std::make_unique over bare new: exception-safe and idiomatic; the
  // unique_ptr<TpuEntryRefImpl> converts implicitly to the base-class
  // unique_ptr on assignment.
  *entry = std::make_unique<TpuEntryRefImpl>(this, cache_entry, proto_index);
  return Status::OK();
}
// Looks up a compiled program by its per-core proto key and returns a
// reference-holding wrapper; the entry cannot be evicted while the returned
// ref is alive. Returns NotFound for an unknown key.
Status TpuCompilationCacheExternal::Lookup(
    const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
                                         /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_proto_key_.find(proto_key);
  if (iter == entries_by_proto_key_.end()) {
    return errors::NotFound("No proto found for key ", proto_key);
  }
  // The index maps proto_key -> (owning subgraph, core index within it).
  CompiledSubgraph* cache_entry = iter->second.first;
  int proto_index = iter->second.second;
  *entry = std::unique_ptr<CompilationCacheEntryRef>(
      new TpuEntryRefImpl(this, cache_entry, proto_index));
  return Status::OK();
}
// Core lookup-or-compile routine. On a cache miss, compiles via
// `compile_function` (with mu_ released during compilation — see
// InitializeEntry) and inserts the result; on a hit, waits for the entry to
// finish initializing and refreshes its LRU position. In both cases one
// caller-owned reference is taken on the entry: it is handed to
// `per_step_ref_holder` when supplied, otherwise counted as an external
// reference that the caller must return via Release(). Entries displaced by
// LRU pressure are appended to `removed_entries` for the caller to unload
// outside the lock. Outputs the entry's uid, per-core proto keys,
// may-modify-variables flags, and HLO metadata.
Status TpuCompilationCacheExternal::CompileIfKeyAbsentHelper(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
    std::vector<CompiledSubgraph*>* removed_entries,
    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
    const std::function<Status(TpuProgramGroup*)>& compile_function) {
  profiler::TraceMe subgraph_lookup_traceme(
      "TPU compilation cache subgraph lookup",
      /*level=*/2);

  // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock
  // for the lifetime of the object, see InitializeEntry() call below.
  absl::MutexLock lock(&mu_);

  std::string cache_key = FindCacheKey(subgraph_key);
  auto iter = cache_store_.find(cache_key);
  bool is_new_key = iter == cache_store_.end();

  const std::string session_name = SessionNameFromMetadata(session_metadata);

  CompiledSubgraph* entry = nullptr;
  if (is_new_key) {
    cache_key = ConstructCompilationCacheKey(subgraph_key);
    TpuCompilationCacheMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/false, session_name);
    const string msg =
        strings::StrCat("TPU host compilation cache miss: cache_key(",
                        cache_key, "), session_name(", session_name, ")");
    TRACESTRING(msg);
    LOG(INFO) << msg;

    // Check if caller has disabled compilation. Set using
    // internal::ScopedTpuCompileDisabler.
    if (!IsTpuCompilationEnabled()) {
      const string error_msg = strings::StrCat(
          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
          "disabled, session_name(",
          session_name, ") Debug String: ", subgraph_key.debug_string);
      if (VLOG_IS_ON(2)) {
        VLOG(2) << "Cache Missed. Current cache entries: ";
        for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) {
          // TODO(henrytan): add DebugKey as cache_entry_debug_string to
          // TpuCompilationCacheKey.
          VLOG(2) << "Cache Debug Info: ";
          VLOG(2) << it->second->cache_entry_debug_string;
        }
      }

      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
      return errors::NotFound(error_msg);
    }

    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << cache_key
            << " with session_name( " << session_name << ");"
            << "; cache is " << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << " marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
    // Note that InitializeEntry() will Release/Reacquire mu_.
    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
    TRACELITERAL("TPU host compilation cache: compilation done.");

    LOG(INFO) << strings::StrCat(
        "TPU host compilation cache: compilation done for cache_key(",
        cache_key, "), session_name(", session_name, ")");
    // If session_name is present, log some additional stats related to HBM
    // here, so that they can be associated directly to the session.
    if (!session_name.empty()) {
      entry->tpu_program_group->LogProgramMemorySummary();
    }
  } else {
    TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name);
    const string msg =
        strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
                        "), session_name(", session_name, ")");
    TRACESTRING(msg);
    VLOG(1) << msg;
    VLOG(1) << "Before refreshing entry for key " << cache_key
            << " with session_name( " << session_name << "); cache is "
            << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << " marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized.
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;
  // Let the caller know the keys for each of the cached protos.
  *proto_key = entry->proto_key;
  *may_modify_variables = entry->tpu_program_group->may_modify_variables();
  *hlo_metadata = entry->tpu_program_group->hlo_metadatas();

  // If the caller didn't supply a per_step_ref_holder then the caller is going
  // to manually release the reference later via a call to Release().
  if (per_step_ref_holder == nullptr) {
    ++entry->external_references;
  } else {
    // The caller wants its reference to be handed off to a per-step holder that
    // will discard the reference when the step completes.
    RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder);
    TF_RET_CHECK(cast_ref_holder != nullptr);
    cast_ref_holder->AddRef(entry);
  }

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table indicating this entry is the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry, removed_entries);
  }

  // Log a little more verbosely when a key is added.
  if (VLOG_IS_ON(1) || is_new_key) {
    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
              << " entry for key " << cache_key << " with session_name "
              << session_name << " cache is " << cache_store_.size()
              << " entries (" << cache_size_ + marked_for_eviction_size_
              << " bytes), "
              << " marked for eviction "
              << (cache_store_.size() - entries_by_last_use_.size())
              << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  return entry->initialization_status;
}
// Public entry point: delegates to CompileIfKeyAbsentHelper (which runs under
// the cache lock) and then unloads, outside the lock, any entries the helper
// evicted under LRU pressure.
tensorflow::Status TpuCompilationCacheExternal::CompileIfKeyAbsent(
    const TpuCompilationCacheKey& cache_key,
    const tensorflow::SessionMetadata* session_metadata,
    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
    const std::function<tensorflow::Status(TpuProgramGroup*)>&
        compile_function) {
  std::vector<CompiledSubgraph*> evicted_entries;
  tensorflow::Status compile_status = CompileIfKeyAbsentHelper(
      cache_key, session_metadata, per_step_ref_holder, uid, proto_key,
      may_modify_variables, &evicted_entries, hlo_metadata, compile_function);
  for (CompiledSubgraph* evicted : evicted_entries) {
    UnloadAndDestroy(evicted);
  }
  return compile_status;
}
} // namespace tpu
} // namespace tensorflow

View File

@ -26,14 +26,11 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
@ -43,25 +40,37 @@ limitations under the License.
namespace tensorflow {
namespace tpu {
constexpr char kCompilationCacheResourceName[] = "tpu_compilation_cache";
constexpr char kCompilationCacheUnloaderResourceName[] =
const char kCompilationCacheResourceName[] = "tpu_compilation_cache";
const char kCompilationCacheUnloaderResourceName[] =
"tpu_compilation_cache_unloader";
class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
// Base class that holds references to compiled protos so that the protos are
// not garbage-collected before being used by execute ops. Use
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
// ref holder object.
class TpuCompilationRefHolder : public ResourceBase {
public:
~TpuCompilationRefHolder() override = default;
};
class TpuCompilationCacheExternal : public ResourceBase {
public:
using Status = ::stream_executor::port::Status;
class EntryRefImpl
: public CompilationCacheEntryRefImpl<TpuCompilationCacheEntry> {
public:
EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry,
int index);
explicit TpuCompilationCacheExternal(int64_t max_cache_size);
~TpuCompilationCacheExternal() override;
TpuCompilationCacheExternal(const TpuCompilationCacheExternal&) = delete;
TpuCompilationCacheExternal& operator=(const TpuCompilationCacheExternal&) =
delete;
TpuCompilationCacheEntry get() override;
};
explicit TpuCompilationCacheExternal(int64 max_cache_size)
: TpuCompilationCacheInterface(max_cache_size) {}
Status CompileIfKeyAbsent(
const TpuCompilationCacheKey& cache_key,
const SessionMetadata* session_metadata,
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
const std::function<tensorflow::Status(TpuProgramGroup*)>&
compile_function);
static TpuCompilationCacheKey CreateCompilationCacheKey(
absl::string_view function_name, uint64 function_library_fingerprint,
@ -73,7 +82,177 @@ class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
string DebugString() const override { return "TpuCompilationCacheExternal"; }
// Makes a reference holder for this cache, that can be stored in the per-step
// resource manager and will ensure that compiled entries persist until the
// end of a step.
TpuCompilationRefHolder* MakePerStepRefHolder();
// Differences between MarkEntryForEviction and Release:
// There are two modes of managing cache entries:
// 1) LRU eviction + pinning; 2) manual.
// We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
// Otherwise it is manual mode (mainly used by XRT).
// MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
// entries when callers know that they do not need them anymore.
// Release should only be used in mode 2) to explicitly remove an entry.
// Mark the entry indexed by `subgraph_uid` for eviction. This should only be
// called if per_step_ref_holder was NOT nullptr in the corresponding call to
// CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
// subgraph_uid).
Status MarkEntryForEviction(int64 subgraph_uid);
// Manually discards a reference to the compiled subgraph. This should only be
// called if per_step_ref_holder was nullptr in the corresponding call to
// CompileIfKeyAbsent(subgraph_key, ...).
Status Release(int64 subgraph_uid);
// Looks up an executable corresponding to the model-parallel core index of
// the subgraph represented by key. On success a pointer to an EntryRef
// holding the program is returned in entry.
Status Lookup(const string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry);
// Looks up an executable corresponding to the model-parallel core index of
// the subgraph represented by uid. On success a pointer to an EntryRef
// holding the program is returned in entry.
Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry);
// Mutates the main entry ref to point to the entry's subentry
// (for sharding/unsharding) or main entry (unchanged) representing the
// fetch target. The entry ref needs to point to the main entry before this
// call.
//
// If the requested subentry does not exist, the ref will point to a nullptr
// entry.
Status ToSubEntryRef(CompilationCacheEntryRef* entry,
CompilationCacheFetchTarget fetch_target) const;
private:
// Wrapper for a cache entry that holds a reference to the entry until the
// wrapper is deleted. This wrapper is the concrete type of
// CompilationCacheEntryRef returned by Lookup.
class TpuEntryRefImpl : public CompilationCacheEntryRef {
public:
TpuEntryRefImpl(TpuCompilationCacheExternal* parent,
CompiledSubgraph* entry, int index);
~TpuEntryRefImpl() override;
TpuCompilationCacheEntry get() override;
// Mutates this ref to point to the entry's subentry (for
// sharding/unsharding) or main entry (unchanged) as specified by
// fetch_target. The refcount is kept unchanged, since we only track the
// refcount of the main entry. The entry ref needs to point to the main
// entry before this call.
//
// If the requested subentry does not exist, the ref will point to a nullptr
// entry, and the original entry will be unref'ed.
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
private:
TpuCompilationCacheExternal* parent_; // Not owned.
// A reference to entry_ is acquired in the constructor and released via
// parent->DiscardEntryRefs in the destructor.
CompiledSubgraph* entry_;
// The program in entry_ that is returned by the get method.
int index_;
};
// Private implementation of the generic CompilationRefHolder that knows about
// CompiledSubgraph entries.
class RefHolder : public TpuCompilationRefHolder {
public:
explicit RefHolder(TpuCompilationCacheExternal* parent) : parent_(parent) {
parent_->Ref();
}
~RefHolder() override {
// Release our reference to the parent.
parent_->Unref();
}
// Adds entry to the list of entries that will be released when the
// RefHolder is destroyed. Each entry is released via a call to
// parent_->DiscardEntryRefs.
void AddRef(CompiledSubgraph* entry) { entries_.push_back(entry); }
string DebugString() const override {
return "TpuCompilationCacheExternal::RefHolder";
}
private:
TpuCompilationCacheExternal* parent_; // Not owned.
std::vector<CompiledSubgraph*> entries_;
};
// The bulk of implementation of CompileIfKeyAbsent() with the exception
// of unloading programs that corresponds to possibly removed cache
// entries. The split helps to manage locking since we prefer to perform
// unloading without holding extra locks.
Status CompileIfKeyAbsentHelper(
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
std::vector<CompiledSubgraph*>* removed_entries,
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
const std::function<Status(TpuProgramGroup*)>& compile_function);
// This is called by the cache when entry is marked for eviction; by
// a RefHolder (via DiscardEntryRefs) when a step completes; and by
// an EntryRefImpl when it is destroyed. Releases one reference to entry
// if more than 1 remains. If only one reference is left, the entry is removed
// from cache_ and is returned to the caller; which must eventually call
// UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
// to avoid holding the lock during program unloading.
ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Convenience method called by ~RefHolder without mu_ held. Calls
// DiscardEntryRef on every element of entries.
void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
// Marks the oldest unmarked entry for eviction. Requires that there is at
// least one such entry. In case the evicted entry had only 1 reference it
// is removed from the cache and returned to the caller which must eventually
// call UnloadAndDestroy.
CompiledSubgraph* MarkOldestEntryForEviction()
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Updates datastructures to indicate that entry, which had been marked for
// eviction, has been looked up. This is called by CompileIfKeyAbsent when an
// entry is newly created, or an entry that has been marked for eviction but
// not yet evicted is looked up.
//
// First the entry is unmarked for eviction, i.e. the cache gains a reference
// to entry, entry's last_use field is set to be the most recent value of
// use_counter_ and entries_by_last_use_ is updated accordingly.
//
// Next, the size of the cache is examined to see if any other entries need to
// be marked for eviction now that entry has been unmarked. While the total
// size of unmarked cached entries is greater than max_cache_size_, entries
// are marked for eviction in LRU order. The most recently used entry is never
// marked for eviction, so an entry larger than the max cache size will remain
// in the cache until it is replaced by something else. In case some entries
// actually were removed from the cache, they are a returned to the caller via
// removed_entries. The caller must eventually delete them by calling
// UnloadAndDestroy.
void LookupEntryMarkedForEviction(
CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Removes the entry with given key from cache.
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Inserts the given key and entry to cache.
void InsertEntry(const std::string& key,
const TpuCompilationCacheKey& subgraph_key,
CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the cache key matching given subgraph_key.
std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Creates a new entry by running initialize_programs and places it in the
// cache to be looked up by key. The new entry is in the 'marked for eviction'
// state (not present in entries_by_last_use_) and the caller is expected to
@ -82,10 +261,61 @@ class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
// **InitializeEntry releases mu_ during the call to initialize_programs.**
CompiledSubgraph* InitializeEntry(
const string& key,
const std::function<Status(TpuProgramGroupInterface*)>&
initialize_program,
const std::function<Status(TpuProgramGroup*)>& initialize_program,
const TpuCompilationCacheKey& subgraph_key)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(TpuCompilationCacheInterface::mu_) override;
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Unloads the program associated with the entry from all local devices
// and deletes the entry itself. It is assumed no one else has a reference
// to it and all related keys had already been removed from the cache.
// The call can perform device IO so no locks should be held while calling it.
void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
// The maximum size of entries that are stored in the cache before entries are
// marked for eviction.
const int64 max_cache_size_;
mutable absl::Mutex mu_;
// The total size of entries that are stored and not marked for eviction.
int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
// The total size of entries that are marked for eviction.
int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
// The value to assign to the last_use field of the next entry that is looked
// up.
int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
// session_key_map_ and fingerprint_key_map_ are used for looking up the
// cache_ key matching a given subgraph key. When doing a lookup, check
// session_key_map_ first to avoid unnecessay fingerprint computation.
// Map from key prefix + session_handle to a cache_ key.
std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
// Map from key prefix + fingerprint to a cache_ key.
std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
// All the subgraph entries that can be looked up in the cache. An entry is
// marked for eviction iff it is present in cache_ and not in
// entries_by_last_use_.
std::unordered_map<string, CompiledSubgraph*> cache_store_
ABSL_GUARDED_BY(mu_);
// All the subgraph entries that can be looked up in the cache, indexed by
// uid.
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
ABSL_GUARDED_BY(mu_);
// All the protos that can be looked up in the cache, indexed by proto
// key. The value of the map is a subgraph and the index of the proto compiled
// for that subgraph.
std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
// Map from last_use to entry, used to mark entries for eviction in LRU
// order. If an entry's last_use counter is not present as a key in
// entries_by_last_use_ then the entry has been marked for eviction.
std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
};
} // namespace tpu

View File

@ -1,355 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
#include "tensorflow/core/tpu/kernels/trace_util.h"
namespace tensorflow {
namespace tpu {
// Base class that holds references to compiled protos so that the protos are
// not garbage-collected before being used by execute ops. Use
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
// ref holder object.
class CompilationRefHolder : public ResourceBase {
 public:
  // Virtual destructor (inherited from ResourceBase) so that concrete ref
  // holders release their entries when deleted through a base pointer.
  ~CompilationRefHolder() override = default;
};
// Base class for a reference to a cached tpu program. A unique_ptr to a
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
// and ensures the underlying proto is not garbage-collected until the client
// discards the ptr.
//
// CacheEntryType is the concrete (value) entry type handed out by get(); it is
// expected to be cheap to return and only valid while this ref is alive.
template <typename CacheEntryType>
class CompilationCacheEntryRef {
 public:
  virtual ~CompilationCacheEntryRef() = default;

  // Returns a CompilationCacheEntry that should not be used beyond the lifetime
  // of the tpu::CompilationCacheEntryRef.
  virtual CacheEntryType get() = 0;

  // Mutates this ref to point to the entry's subentry (for
  // sharding/unsharding) or main entry (unchanged) as specified by
  // fetch_target. The refcount is kept unchanged, since we only track the
  // refcount of the main entry. The entry ref needs to point to the main
  // entry before this call.
  //
  // If the requested subentry does not exist, the ref will point to a nullptr
  // entry, and the original entry will be unref'ed.
  virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0;
};
// Thread-safe cache of compiled TPU programs, keyed by TpuCompilationCacheKey.
// Entries are reference counted; eviction is LRU once the total size of
// unmarked entries exceeds max_cache_size_. Subclasses implement
// InitializeEntry to perform the actual compilation.
class TpuCompilationCacheInterface : public ResourceBase {
 public:
  explicit TpuCompilationCacheInterface(int64 max_cache_size);
  ~TpuCompilationCacheInterface() override;
  // Ensures there is an entry for key present in the cache. By the time
  // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
  // for key, and that entry will remain valid at least until
  // per_step_ref_holder is deleted. The first call to CompileIfKeyAbsent with a
  // key that is not in the cache will evaluate compile_function to compute the
  // value to use in the entry. Subsequent calls with the same key will block
  // until compile_function completes. Other cache reads and inserts may proceed
  // on other threads while compile_function is executing. If
  // per_step_ref_holder is nullptr then the caller is responsible for calling
  // Release(subgraph_key) to manually discard its reference to the compiled
  // program, once the caller will not look up the compiled program again.
  //
  // compile_function should compile the subgraph represented by key and fill in
  // one TPUExecutableProto per model-parallel core into its passed argument. It
  // should return OK if and only if compilation succeeds. The executable proto
  // vector will be discarded on non-OK status.
  Status CompileIfKeyAbsent(
      const TpuCompilationCacheKey& subgraph_key,
      const SessionMetadata* session_metadata,
      CompilationRefHolder* per_step_ref_holder, int64* uid,
      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
      std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas,
      const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
  // Differences between MarkEntryForEviction and Release:
  // There are two modes of managing cache entries:
  // 1) LRU eviction + pinning; 2) manual.
  // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
  // Otherwise it is manual mode (mainly used by XRT).
  // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
  // entries when callers know that they do not need them anymore.
  // Release should only be used in mode 2) to explicitly remove an entry.
  // Mark the entry indexed by `subgraph_uid` for eviction. This should only be
  // called if per_step_ref_holder was NOT nullptr in the corresponding call to
  // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
  // subgraph_uid).
  Status MarkEntryForEviction(int64 subgraph_uid);
  // Manually discards a reference to the compiled subgraph. This should only be
  // called if per_step_ref_holder was nullptr in the corresponding call to
  // CompileIfKeyAbsent(subgraph_key, ...).
  Status Release(int64 subgraph_uid);
  // Looks up an executable corresponding to the model-parallel core index of
  // the subgraph represented by key. On success a pointer to an EntryRef
  // holding the program is returned in entry.
  template <typename CacheEntryRef, typename CacheEntryRefImpl>
  Status Lookup(const string& proto_key, std::unique_ptr<CacheEntryRef>* entry);
  // Looks up an executable corresponding to the model-parallel core index of
  // the subgraph represented by uid. On success a pointer to an EntryRef
  // holding the program is returned in entry.
  template <typename CacheEntryRef, typename CacheEntryRefImpl>
  Status Lookup(int64 uid, int proto_index,
                std::unique_ptr<CacheEntryRef>* entry);
  // Looks up the subgraph represented by uid, and returns the vector of keys,
  // one per core, corresponding to that subgraph.
  Status GetKeysFromUid(int64 uid, std::vector<string>* keys);
  // Makes a reference holder for this cache, that can be stored in the per-step
  // resource manager and will ensure that compiled entries persist until the
  // end of a step.
  CompilationRefHolder* MakePerStepRefHolder();
  // Convenience method called by ~RefHolder without mu_ held. Calls
  // DiscardEntryRef on every element of entries.
  void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
  string DebugString() const override { return "TpuCompilationCacheBase"; }
 protected:
  // Returns the full cache key for `key`: just the prefix when the subgraph
  // has no guaranteed constants, otherwise the prefix combined with the
  // session handle and the guaranteed-const fingerprint.
  std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
    if (!key.has_guaranteed_const) {
      return key.prefix;
    }
    return absl::StrCat(key.prefix, "|", key.session_handle, "|",
                        key.guaranteed_const_fingerprint());
  }
  // Private implementation of the generic CompilationRefHolder that knows about
  // CompiledSubgraph entries.
  class RefHolder : public CompilationRefHolder {
   public:
    explicit RefHolder(TpuCompilationCacheInterface* parent);
    ~RefHolder() override;
    // Adds entry to the list of entries that will be released when the
    // RefHolder is destroyed. Each entry is released via a call to
    // parent_->DiscardEntryRefs.
    void AddRef(CompiledSubgraph* entry);
    string DebugString() const override;
   private:
    TpuCompilationCacheInterface* parent_;  // Not owned.
    std::vector<CompiledSubgraph*> entries_;
  };
  // The bulk of implementation of CompileIfKeyAbsent() with the exception
  // of unloading programs that corresponds to possibly removed cache
  // entries. The split helps to manage locking since we prefer to perform
  // unloading without holding extra locks.
  Status CompileIfKeyAbsentHelper(
      const TpuCompilationCacheKey& subgraph_key,
      const SessionMetadata* session_metadata,
      CompilationRefHolder* per_step_ref_holder, int64* uid,
      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
      std::vector<CompiledSubgraph*>* removed_entries,
      std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas,
      const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
  // This is called by the cache when entry is marked for eviction; by
  // a RefHolder (via DiscardEntryRefs) when a step completes; and by
  // an EntryRefImpl when it is destroyed. Releases one reference to entry
  // if more than 1 remains. If only one reference is left, the entry is removed
  // from cache_ and is returned to the caller; which must eventually call
  // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
  // to avoid holding the lock during program unloading.
  ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
      CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Marks the oldest unmarked entry for eviction. Requires that there is at
  // least one such entry. In case the evicted entry had only 1 reference it
  // is removed from the cache and returned to the caller which must eventually
  // call UnloadAndDestroy.
  ABSL_MUST_USE_RESULT CompiledSubgraph* MarkOldestEntryForEviction()
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Updates datastructures to indicate that entry, which had been marked for
  // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
  // entry is newly created, or an entry that has been marked for eviction but
  // not yet evicted is looked up.
  //
  // First the entry is unmarked for eviction, i.e. the cache gains a reference
  // to entry, entry's last_use field is set to be the most recent value of
  // use_counter_ and entries_by_last_use_ is updated accordingly.
  //
  // Next, the size of the cache is examined to see if any other entries need to
  // be marked for eviction now that entry has been unmarked. While the total
  // size of unmarked cached entries is greater than max_cache_size_, entries
  // are marked for eviction in LRU order. The most recently used entry is never
  // marked for eviction, so an entry larger than the max cache size will remain
  // in the cache until it is replaced by something else. In case some entries
  // actually were removed from the cache, they are returned to the caller via
  // removed_entries. The caller must eventually delete them by calling
  // UnloadAndDestroy.
  void LookupEntryMarkedForEviction(
      CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Removes the entry with given key from cache.
  size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Inserts the given key and entry to cache.
  void InsertEntry(const string& key, CompiledSubgraph* entry)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Returns the cache key matching given subgraph_key.
  string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a new entry by running initialize_programs and places it in the
  // cache to be looked up by key. The new entry is in the 'marked for eviction'
  // state (not present in entries_by_last_use_) and the caller is expected to
  // call LookupEntryMarkedForEviction after InitializeEntry.
  //
  // **InitializeEntry releases mu_ during the call to initialize_programs.**
  virtual CompiledSubgraph* InitializeEntry(
      const string& key,
      const std::function<Status(TpuProgramGroupInterface*)>&
          initialize_programs,
      const TpuCompilationCacheKey& subgraph_key)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
  // Unloads the program associated with the entry from all local devices
  // and deletes the entry itself. It is assumed no one else has a reference
  // to it and all related keys had already been removed from the cache.
  // The call can perform device IO so no locks should be held while calling it.
  void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
  // The maximum size of entries that are stored in the cache before entries are
  // marked for eviction.
  const int64 max_cache_size_;
  // Mutex to protect access to shared resources under multi-threading
  // environment.
  absl::Mutex mu_;
  // The total size of entries that are stored and not marked for eviction.
  int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
  // The total size of entries that are marked for eviction.
  int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
  // The value to assign to the last_use field of the next entry that is looked
  // up.
  int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
  // session_key_map_ and fingerprint_key_map_ are used for looking up the
  // cache_ key matching a given subgraph key. When doing a lookup, check
  // session_key_map_ first to avoid unnecessary fingerprint computation.
  // Map from key prefix + session_handle to a cache_ key.
  absl::node_hash_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
  // Map from key prefix + fingerprint to a cache_ key.
  absl::node_hash_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
  // All the subgraph entries that can be looked up in the cache. An entry is
  // marked for eviction iff it is present in cache_ and not in
  // entries_by_last_use_.
  std::unordered_map<string, CompiledSubgraph*> cache_ ABSL_GUARDED_BY(mu_);
  // All the subgraph entries that can be looked up in the cache, indexed by
  // uid.
  absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
      ABSL_GUARDED_BY(mu_);
  // All the protos that can be looked up in the cache, indexed by proto
  // key. The value of the map is a subgraph and the index of the proto compiled
  // for that subgraph.
  std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
      entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
  // Map from last_use to entry, used to mark entries for eviction in LRU
  // order. If an entry's last_use counter is not present as a key in
  // entries_by_last_use_ then the entry has been marked for eviction.
  std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
  // Metrics collector for this compilation cache (see
  // TpuCompilationCacheMetrics).
  TpuCompilationCacheMetrics tpu_compilation_cache_metrics_;
 private:
  // Non-copyable: the cache owns mutexes and ref-counted entries.
  TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
  TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
      delete;
};
// Looks up a program by subgraph uid and model-parallel core index. On
// success, *entry holds a CacheEntryRefImpl referencing the program; on
// failure, *entry is left empty and a NotFound status is returned.
template <typename CacheEntryRef, typename CacheEntryRefImpl>
Status TpuCompilationCacheInterface::Lookup(
    int64 uid, int proto_index, std::unique_ptr<CacheEntryRef>* entry) {
  // Clear the output first so a failed lookup leaves it empty.
  entry->reset();
  profiler::TraceMe proto_lookup_traceme(
      "TPU compilation cache proto lookup by uid",
      /*level=*/2);
  absl::MutexLock lock(&mu_);
  auto subgraph_it = entries_by_uid_.find(uid);
  if (subgraph_it == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  CompiledSubgraph* const subgraph = subgraph_it->second;
  const bool index_in_range =
      proto_index >= 0 &&
      proto_index < subgraph->tpu_program_group->program_count();
  if (!index_in_range) {
    return errors::NotFound("No proto found for core index ", proto_index,
                            " in subgraph with uid ", uid);
  }
  *entry = absl::make_unique<CacheEntryRefImpl>(this, subgraph, proto_index);
  return Status::OK();
}
// Looks up a program by proto key. On success, *entry holds a
// CacheEntryRefImpl referencing the program for the core index recorded under
// that key; on failure, *entry is left empty and NotFound is returned.
template <typename CacheEntryRef, typename CacheEntryRefImpl>
Status TpuCompilationCacheInterface::Lookup(
    const string& proto_key, std::unique_ptr<CacheEntryRef>* entry) {
  // Clear the output first so a failed lookup leaves it empty.
  entry->reset();
  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
                                         /*level=*/2);
  absl::MutexLock lock(&mu_);
  auto proto_it = entries_by_proto_key_.find(proto_key);
  if (proto_it == entries_by_proto_key_.end()) {
    return errors::NotFound("No proto found for key ", proto_key);
  }
  CompiledSubgraph* const subgraph = proto_it->second.first;
  const int core_index = proto_it->second.second;
  *entry = absl::make_unique<CacheEntryRefImpl>(this, subgraph, core_index);
  return Status::OK();
}
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_

View File

@ -42,7 +42,7 @@ std::string GetName(CompilationCacheFetchTarget target) {
} // namespace
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
TpuCompilationCacheInterface* cache)
TpuCompilationCacheExternal* cache)
: cache_(cache) {}
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
@ -50,19 +50,17 @@ TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
}
Status TpuCompilationCacheLocalLookup::Lookup(
const string& proto_key,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) {
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
/*level=*/2);
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
proto_key, entry);
Status s = cache_->Lookup(proto_key, entry);
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
<< s;
if (!s.ok()) {
return s;
}
s = (*entry)->ToSubEntryRef(fetch_target);
s = cache_->ToSubEntryRef(entry->get(), fetch_target);
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
<< s;
@ -71,18 +69,17 @@ Status TpuCompilationCacheLocalLookup::Lookup(
Status TpuCompilationCacheLocalLookup::Lookup(
int64 uid, int proto_index,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) {
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
/*level=*/2);
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
uid, proto_index, entry);
Status s = cache_->Lookup(uid, proto_index, entry);
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
<< " in local subgraph cache status " << s;
if (!s.ok()) {
return s;
}
s = (*entry)->ToSubEntryRef(fetch_target);
s = cache_->ToSubEntryRef(entry->get(), fetch_target);
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
<< s;
return s;

View File

@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
namespace tensorflow {
namespace tpu {
@ -30,11 +28,6 @@ namespace tpu {
// and when they need to communicate over RPC.
class TpuCompilationCacheLookup : public ResourceBase {
public:
using TpuCompilationCacheEntryRef =
::tensorflow::tpu::CompilationCacheEntryRef<TpuCompilationCacheEntry>;
using EntryRefImpl =
::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl;
~TpuCompilationCacheLookup() override = default;
// Looks up an executable corresponding to the model-parallel core index of
@ -49,11 +42,11 @@ class TpuCompilationCacheLookup : public ResourceBase {
// fetch_target requests one of them, then after this call
// (*entry)->get().get_executable() will return nullptr.
virtual Status Lookup(const string& proto_key,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) = 0;
virtual Status Lookup(const string& proto_key,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry) {
std::unique_ptr<CompilationCacheEntryRef>* entry) {
return Lookup(proto_key, std::move(entry),
CompilationCacheFetchTarget::MAIN);
}
@ -63,30 +56,33 @@ class TpuCompilationCacheLookup : public ResourceBase {
// returned in program. The wrapper is guaranteed to be valid only during the
// execution of the Op requesting the proto.
virtual Status Lookup(int64 uid, int proto_index,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) = 0;
virtual Status Lookup(int64 uid, int proto_index,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry) {
std::unique_ptr<CompilationCacheEntryRef>* entry) {
return Lookup(uid, proto_index, std::move(entry),
CompilationCacheFetchTarget::MAIN);
}
};
// Forward declaration to break cycle dependency graph.
class TpuCompilationCacheExternal;
// Class for looking up ISA protos when the execute and compile Op are in the
// same address space. The proto is simply looked up in the compilation cache,
// without any serialization taking place.
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
public:
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheExternal* cache);
~TpuCompilationCacheLocalLookup() override;
Status Lookup(const string& proto_key,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) override;
Status Lookup(int64 uid, int proto_index,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) override;
string DebugString() const override;
@ -94,10 +90,10 @@ class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
private:
// The subgraph compilation cache, in the same process address space where the
// lookups are happening.
TpuCompilationCacheInterface* cache_;
TpuCompilationCacheExternal* cache_;
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"