TPU library internal change.
PiperOrigin-RevId: 316171774 Change-Id: I85b3c2639c734e391692f568bffae0efe116e9af
This commit is contained in:
parent
9130ba73e6
commit
1e3eddd152
|
@ -19,9 +19,9 @@ cc_library(
|
|||
deps = [
|
||||
":tpu_compile_op_support",
|
||||
":tpu_mesh_state_interface",
|
||||
":tpu_program_group_interface",
|
||||
":tpu_util",
|
||||
":tpu_util_hdrs",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:shape_inference",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
|
@ -30,16 +30,16 @@ cc_library(
|
|||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:compile_only_client",
|
||||
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
# "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
|
||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
|
||||
"//tensorflow/core/tpu:tpu_configuration",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -157,28 +157,14 @@ cc_library(
|
|||
"tpu_compilation_cache_entry.h",
|
||||
],
|
||||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
":tpu_executable_info_proto_cc",
|
||||
":tpu_program_group",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/lib/core:refcount",
|
||||
"//tensorflow/core/platform:casts",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_entry_impl",
|
||||
srcs = [],
|
||||
hdrs = ["tpu_compilation_cache_entry_impl.h"],
|
||||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_executable_info_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_lookup",
|
||||
srcs = ["tpu_compilation_cache_lookup.cc"],
|
||||
|
@ -188,7 +174,6 @@ cc_library(
|
|||
deps = [
|
||||
":tpu_compilation_cache_entry",
|
||||
":tpu_compilation_cache_external",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
"//tensorflow/core/lib/core:refcount",
|
||||
"//tensorflow/core/platform:status",
|
||||
|
@ -262,35 +247,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_interface",
|
||||
srcs = ["tpu_compilation_cache_interface.cc"],
|
||||
hdrs = ["tpu_compilation_cache_interface.h"],
|
||||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_key",
|
||||
":tpu_compilation_cache_metrics_hdrs",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
":tpu_util",
|
||||
":tpu_util_hdrs",
|
||||
":trace_util_hdrs",
|
||||
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
||||
"//tensorflow/core/platform:casts", # buildcleaner: keep
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_external",
|
||||
srcs = ["tpu_compilation_cache_external.cc"],
|
||||
|
@ -300,8 +256,6 @@ cc_library(
|
|||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_entry",
|
||||
":tpu_compilation_cache_entry_impl",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_cache_key",
|
||||
":tpu_compilation_cache_metrics", # buildcleaner: keep
|
||||
":tpu_compilation_cache_metrics_hdrs",
|
||||
|
@ -401,7 +355,6 @@ cc_library(
|
|||
"//tensorflow/compiler/xla/client:compile_only_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
|
|
@ -25,9 +25,6 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
class TpuCompilationCacheInterface;
|
||||
|
||||
// Cache for compiled TPU program.
|
||||
//
|
||||
// Each key identifies a unique subgraph, and the value is the vector of
|
||||
|
@ -103,7 +100,10 @@ class TpuCompilationCacheInterface;
|
|||
// unmarked and set to most recently used.
|
||||
//
|
||||
struct CompiledSubgraph : public core::RefCounted {
|
||||
TpuCompilationCacheInterface* parent = nullptr; // Not owned.
|
||||
// TODO(henrytan): once `TpuCompilationCache` and
|
||||
// `TpuCompilationCacheExternal` inherits from `TpuCompilationCacheInterface`
|
||||
// update void* with `TpuCompilationCacheInterface`
|
||||
void* parent = nullptr; // Not owned.
|
||||
|
||||
bool initialized = false;
|
||||
|
||||
|
@ -145,7 +145,7 @@ struct CompiledSubgraph : public core::RefCounted {
|
|||
// owning main entry.
|
||||
CompiledSubgraph* main_entry = nullptr;
|
||||
|
||||
// Compiled TPU program group.
|
||||
// Compiled Tpu program.
|
||||
std::unique_ptr<TpuProgramGroupInterface> tpu_program_group;
|
||||
|
||||
// Computes total program size.
|
||||
|
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// A version of `CompilationCacheEntry` to access Tpu binary program
|
||||
// A version of `CompilationCacheEntry` that exposes Tpu binary program
|
||||
// `XLA_TpuProgram`.
|
||||
class TpuCompilationCacheEntry {
|
||||
public:
|
||||
|
@ -42,6 +42,28 @@ class TpuCompilationCacheEntry {
|
|||
int core_index_;
|
||||
};
|
||||
|
||||
// Base class for a reference to a cached proto. A unique_ptr to a
|
||||
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
|
||||
// and ensures the underlying proto is not garbage-collected until the client
|
||||
// discards the ptr.
|
||||
class CompilationCacheEntryRef {
|
||||
public:
|
||||
virtual ~CompilationCacheEntryRef() = default;
|
||||
|
||||
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
|
||||
// of the CompilationCacheEntryRef.
|
||||
virtual TpuCompilationCacheEntry get() = 0;
|
||||
};
|
||||
|
||||
// Base class that holds references to compiled protos so that the protos are
|
||||
// not garbage-collected before being used by execute ops. Use
|
||||
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
|
||||
// ref holder object.
|
||||
class CompilationRefHolder : public ResourceBase {
|
||||
public:
|
||||
~CompilationRefHolder() override = default;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
||||
|
||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// Wrapper for a cache entry that holds a reference to the entry until the
|
||||
// wrapper is deleted. This wrapper is the concrete type of
|
||||
// CompilationCacheEntryRef returned by Lookup.
|
||||
template <typename CacheEntryType>
|
||||
class CompilationCacheEntryRefImpl
|
||||
: public CompilationCacheEntryRef<CacheEntryType> {
|
||||
public:
|
||||
CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent,
|
||||
CompiledSubgraph* entry, int index);
|
||||
|
||||
~CompilationCacheEntryRefImpl() override;
|
||||
|
||||
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override;
|
||||
|
||||
protected:
|
||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
||||
// A reference to entry_ is acquired in the constructor and released via
|
||||
// parent->DiscardEntryRefs in the destructor.
|
||||
CompiledSubgraph* entry_;
|
||||
// The index of the program in entry_ that is returned by the get method.
|
||||
int index_;
|
||||
};
|
||||
|
||||
template <typename CacheEntryType>
|
||||
CompilationCacheEntryRefImpl<CacheEntryType>::CompilationCacheEntryRefImpl(
|
||||
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
||||
: parent_(parent), entry_(entry), index_(index) {
|
||||
if (entry_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (entry_->main_entry == nullptr) {
|
||||
entry_->Ref();
|
||||
} else {
|
||||
// This is a sharding/unsharding entry nested in a main entry. Only
|
||||
// refcount the main entry.
|
||||
entry_->main_entry->Ref();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CacheEntryType>
|
||||
CompilationCacheEntryRefImpl<CacheEntryType>::~CompilationCacheEntryRefImpl() {
|
||||
if (entry_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (entry_->main_entry == nullptr) {
|
||||
parent_->DiscardEntryRefs({entry_});
|
||||
} else {
|
||||
parent_->DiscardEntryRefs({entry_->main_entry});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CacheEntryType>
|
||||
Status CompilationCacheEntryRefImpl<CacheEntryType>::ToSubEntryRef(
|
||||
CompilationCacheFetchTarget fetch_target) {
|
||||
CompiledSubgraph* target = nullptr;
|
||||
switch (fetch_target) {
|
||||
case CompilationCacheFetchTarget::MAIN:
|
||||
target = entry_;
|
||||
break;
|
||||
case CompilationCacheFetchTarget::SHARDING:
|
||||
target = entry_->sharding_entry.get();
|
||||
break;
|
||||
case CompilationCacheFetchTarget::UNSHARDING:
|
||||
target = entry_->unsharding_entry.get();
|
||||
break;
|
||||
default:
|
||||
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
|
||||
}
|
||||
|
||||
if (target == nullptr) {
|
||||
// Cache entry does not have an unsharding subentry. Unref and replace
|
||||
// with nullptr.
|
||||
parent_->DiscardEntryRefs({entry_});
|
||||
}
|
||||
// Otherwise, since the refcount is always on the main entry, we don't
|
||||
// need ref/unref.
|
||||
entry_ = target;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
|
@ -50,6 +50,14 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry,
|
|||
entry->initialized = true;
|
||||
}
|
||||
|
||||
std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
|
||||
if (!key.has_guaranteed_const) {
|
||||
return key.prefix;
|
||||
}
|
||||
return absl::StrCat(key.prefix, "|", key.session_handle, "|",
|
||||
key.guaranteed_const_fingerprint());
|
||||
}
|
||||
|
||||
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
|
||||
// data to compute the fingerprint.
|
||||
std::string GuaranteedConstFingerprint(
|
||||
|
@ -115,32 +123,85 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
|
|||
|
||||
} // namespace
|
||||
|
||||
TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl(
|
||||
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
||||
: CompilationCacheEntryRefImpl<TpuCompilationCacheEntry>(parent, entry,
|
||||
index) {}
|
||||
|
||||
TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() {
|
||||
if (entry_ == nullptr) {
|
||||
// Create an empty entry if the entry is nullptr. This corresponds to
|
||||
// non-existing sharding/unsharding entries.
|
||||
return TpuCompilationCacheEntry();
|
||||
TpuCompilationCacheExternal::TpuCompilationCacheExternal(int64_t max_cache_size)
|
||||
: max_cache_size_(max_cache_size) {
|
||||
if (max_cache_size < 0) {
|
||||
LOG(FATAL) << "`max_cache_size` value must be greater than equal to 0";
|
||||
}
|
||||
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
|
||||
VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
|
||||
}
|
||||
|
||||
TpuCompilationCacheExternal::~TpuCompilationCacheExternal() {
|
||||
VLOG(1) << "TpuCompilationCacheExternal::~TpuCompilationCacheExternal()";
|
||||
// A buggy client may be holding onto a reference, or a client might have
|
||||
// crashed while holding onto a reference. In either case, discard all
|
||||
// outstanding client references to avoid leaking storage.
|
||||
for (const auto& entry : entries_by_uid_) {
|
||||
while (entry.second->external_references > 0) {
|
||||
TF_CHECK_OK(Release(entry.first));
|
||||
}
|
||||
}
|
||||
while (!entries_by_last_use_.empty()) {
|
||||
UnloadAndDestroy(MarkOldestEntryForEviction());
|
||||
}
|
||||
// By the time the cache is deleted all reference holders should have already
|
||||
// been deleted, since they were holding references to the cache. So all
|
||||
// entries should be gone at this point.
|
||||
CHECK_EQ(cache_store_.size(), 0);
|
||||
CHECK_EQ(entries_by_uid_.size(), 0);
|
||||
CHECK_EQ(entries_by_proto_key_.size(), 0);
|
||||
CHECK_EQ(cache_size_, 0);
|
||||
CHECK_EQ(marked_for_eviction_size_, 0);
|
||||
}
|
||||
|
||||
std::string TpuCompilationCacheExternal::FindCacheKey(
|
||||
const TpuCompilationCacheKey& subgraph_key) const {
|
||||
if (!subgraph_key.has_guaranteed_const) {
|
||||
return subgraph_key.prefix;
|
||||
}
|
||||
auto iter = session_key_map_.find(
|
||||
strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
|
||||
if (iter != session_key_map_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
iter = fingerprint_key_map_.find(strings::StrCat(
|
||||
subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
|
||||
if (iter != session_key_map_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
VLOG(1) << "No matching cache key found for key "
|
||||
<< ConstructCompilationCacheKey(subgraph_key);
|
||||
return "";
|
||||
}
|
||||
|
||||
void TpuCompilationCacheExternal::InsertEntry(
|
||||
const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key,
|
||||
CompiledSubgraph* entry) {
|
||||
entry->parent = this;
|
||||
entry->subgraph_key = cache_key;
|
||||
entry->uid = get_uid();
|
||||
TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
|
||||
entry->cache_entry_debug_string = subgraph_key.prefix;
|
||||
VLOG(1) << "Cache Initializing Entry Session Debug "
|
||||
<< entry->cache_entry_debug_string;
|
||||
|
||||
if (!subgraph_key.has_guaranteed_const) {
|
||||
return;
|
||||
}
|
||||
session_key_map_.insert(std::make_pair(
|
||||
strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle),
|
||||
cache_key));
|
||||
fingerprint_key_map_.insert(std::make_pair(
|
||||
strings::StrCat(subgraph_key.prefix,
|
||||
subgraph_key.guaranteed_const_fingerprint()),
|
||||
cache_key));
|
||||
}
|
||||
|
||||
CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
||||
const string& key,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>& initialize_program,
|
||||
const std::function<Status(TpuProgramGroup*)>& initialize_program,
|
||||
const TpuCompilationCacheKey& subgraph_key) {
|
||||
CompiledSubgraph* main_entry = new CompiledSubgraph();
|
||||
main_entry->parent = this;
|
||||
main_entry->subgraph_key = key;
|
||||
main_entry->uid = get_uid();
|
||||
// TODO(henrytan): implement TpuCompilationCacheKey.debug_string.
|
||||
main_entry->cache_entry_debug_string = subgraph_key.prefix;
|
||||
VLOG(1) << "Cache Initializing Entry Session Debug "
|
||||
<< main_entry->cache_entry_debug_string;
|
||||
|
||||
// Add the entry to the cache, with size zero since there are no compiled
|
||||
// programs in it. Once the subgraph has been compiled,
|
||||
|
@ -151,7 +212,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
|||
// who created the entry. A second reference, owned by the cache, will be
|
||||
// added below since we leave the entry in the 'marked for eviction' state
|
||||
// here.
|
||||
InsertEntry(key, main_entry);
|
||||
InsertEntry(key, subgraph_key, main_entry);
|
||||
|
||||
// Initialize the programs outside the lock so that other cache operations
|
||||
// can proceed during the (potentially lengthy) initialization.
|
||||
|
@ -259,5 +320,470 @@ TpuCompilationCacheExternal::CreateCompilationCacheKey(
|
|||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
TpuCompilationRefHolder* TpuCompilationCacheExternal::MakePerStepRefHolder() {
|
||||
return new RefHolder(this);
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheExternal::MarkEntryForEviction(int64 subgraph_uid) {
|
||||
profiler::TraceMe key_release_traceme(
|
||||
"TPU compilation cache possibly evict uid",
|
||||
/*level=*/2);
|
||||
CompiledSubgraph* deleted_entry = nullptr;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto iter = entries_by_uid_.find(subgraph_uid);
|
||||
if (iter == entries_by_uid_.end()) {
|
||||
// If already evicted, return ok.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Mark entry for eviction.
|
||||
CompiledSubgraph* subgraph_to_evict = iter->second;
|
||||
// If there are external references, should not use this API.
|
||||
if (subgraph_to_evict->external_references != 0) {
|
||||
return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
|
||||
" external_references greater than zero. Should "
|
||||
"use TpuCompilationCache::Release.");
|
||||
}
|
||||
|
||||
VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction";
|
||||
entries_by_last_use_.erase(subgraph_to_evict->last_use);
|
||||
cache_size_ -= subgraph_to_evict->total_size;
|
||||
marked_for_eviction_size_ += subgraph_to_evict->total_size;
|
||||
|
||||
// Evict if refcount exactly one, otherwise only discard cache's reference
|
||||
// to the entry while the actual eviction will happen when refholder's
|
||||
// references go away.
|
||||
deleted_entry = DiscardEntryRef(subgraph_to_evict);
|
||||
|
||||
VLOG(1) << "After possibly evicting entry " << subgraph_uid
|
||||
<< " refs cache is " << cache_store_.size() << " entries ("
|
||||
<< cache_size_ + marked_for_eviction_size_
|
||||
<< " bytes), marked for eviction "
|
||||
<< (cache_store_.size() - entries_by_last_use_.size())
|
||||
<< " entries (" << marked_for_eviction_size_ << " bytes).";
|
||||
}
|
||||
|
||||
// Unload from device cache if entry is evicted from host cache.
|
||||
UnloadAndDestroy(deleted_entry);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheExternal::Release(int64 subgraph_uid) {
|
||||
profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
|
||||
/*level=*/2);
|
||||
|
||||
CompiledSubgraph* deleted_entry = nullptr;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto iter = entries_by_uid_.find(subgraph_uid);
|
||||
|
||||
if (iter == entries_by_uid_.end()) {
|
||||
return errors::NotFound("No cache entry found for uid ", subgraph_uid);
|
||||
}
|
||||
|
||||
CHECK_GT(iter->second->external_references, 0);
|
||||
--iter->second->external_references;
|
||||
|
||||
deleted_entry = DiscardEntryRef(iter->second);
|
||||
|
||||
VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
|
||||
<< cache_store_.size() << " entries ("
|
||||
<< cache_size_ + marked_for_eviction_size_
|
||||
<< " bytes), marked for eviction "
|
||||
<< (cache_store_.size() - entries_by_last_use_.size())
|
||||
<< " entries (" << marked_for_eviction_size_ << " bytes).";
|
||||
}
|
||||
UnloadAndDestroy(deleted_entry);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void TpuCompilationCacheExternal::UnloadAndDestroy(CompiledSubgraph* entry) {
|
||||
if (!entry) return;
|
||||
|
||||
CHECK(entry->RefCountIsOne());
|
||||
entry->tpu_program_group->UnloadAndDestroyPrograms();
|
||||
entry->Unref();
|
||||
}
|
||||
|
||||
size_t TpuCompilationCacheExternal::RemoveEntry(const string& key) {
|
||||
auto erased = cache_store_.erase(key);
|
||||
TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
|
||||
auto parsed_key_or_status = ParseCompilationCacheKey(key);
|
||||
CHECK(parsed_key_or_status.status().ok());
|
||||
const TpuCompilationCacheKey parsed_key =
|
||||
parsed_key_or_status.ConsumeValueOrDie();
|
||||
if (!parsed_key.has_guaranteed_const) {
|
||||
return erased;
|
||||
}
|
||||
session_key_map_.erase(
|
||||
strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
|
||||
fingerprint_key_map_.erase(strings::StrCat(
|
||||
parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
|
||||
return erased;
|
||||
}
|
||||
|
||||
ABSL_MUST_USE_RESULT CompiledSubgraph*
|
||||
TpuCompilationCacheExternal::DiscardEntryRef(CompiledSubgraph* entry) {
|
||||
if (entry->RefCountIsOne()) {
|
||||
// The last reference to this entry is going away, so really delete it from
|
||||
// the cache in such a way that it can't be restored by being looked up
|
||||
// again.
|
||||
|
||||
// Sanity-check that it has been marked for eviction.
|
||||
CHECK(entries_by_last_use_.find(entry->last_use) ==
|
||||
entries_by_last_use_.end());
|
||||
// Update the counter tracking how much space is taken up by entries that
|
||||
// are marked for eviction.
|
||||
marked_for_eviction_size_ -= entry->total_size;
|
||||
|
||||
// Remove the entry from the cache.
|
||||
auto erased = RemoveEntry(entry->subgraph_key);
|
||||
|
||||
if (erased == 0) {
|
||||
LOG(FATAL) << "Tried to discard nonexistent cache entry";
|
||||
}
|
||||
erased = entries_by_uid_.erase(entry->uid);
|
||||
CHECK_EQ(erased, 1);
|
||||
for (const string& key : entry->proto_key) {
|
||||
erased = entries_by_proto_key_.erase(key);
|
||||
CHECK_EQ(erased, 1);
|
||||
}
|
||||
// The actual deletion will happen outside the lock in UnloadAndDestroy().
|
||||
return entry;
|
||||
}
|
||||
entry->Unref();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void TpuCompilationCacheExternal::DiscardEntryRefs(
|
||||
gtl::ArraySlice<CompiledSubgraph*> entries) {
|
||||
std::vector<CompiledSubgraph*> removed_entries;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
for (auto entry : entries) {
|
||||
removed_entries.push_back(DiscardEntryRef(entry));
|
||||
}
|
||||
|
||||
VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
|
||||
<< " entries (" << cache_size_ + marked_for_eviction_size_
|
||||
<< " bytes), marked for eviction "
|
||||
<< (cache_store_.size() - entries_by_last_use_.size())
|
||||
<< " entries (" << marked_for_eviction_size_ << " bytes).";
|
||||
}
|
||||
for (auto removed_entry : removed_entries) {
|
||||
UnloadAndDestroy(removed_entry);
|
||||
}
|
||||
}
|
||||
|
||||
ABSL_MUST_USE_RESULT CompiledSubgraph*
|
||||
TpuCompilationCacheExternal::MarkOldestEntryForEviction() {
|
||||
CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
|
||||
VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
|
||||
entries_by_last_use_.erase(entry_to_mark->last_use);
|
||||
cache_size_ -= entry_to_mark->total_size;
|
||||
marked_for_eviction_size_ += entry_to_mark->total_size;
|
||||
// Discard the cache's reference to entry. If steps are holding onto
|
||||
// references to entry it won't be deleted until the last step holding it
|
||||
// completes. It stays in the cache in the meantime and can be resurrected
|
||||
// by a call to CompileIfKeyAbsent if that occurs before the last reference
|
||||
// expires.
|
||||
return DiscardEntryRef(entry_to_mark);
|
||||
}
|
||||
|
||||
void TpuCompilationCacheExternal::LookupEntryMarkedForEviction(
|
||||
CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) {
|
||||
// The entry was previously marked for eviction (or is newly created) so
|
||||
// unmark it. Add a reference (owned by the cache), update the cache size, and
|
||||
// mark something old for eviction if necessary.
|
||||
entry->Ref();
|
||||
marked_for_eviction_size_ -= entry->total_size;
|
||||
cache_size_ += entry->total_size;
|
||||
|
||||
// Mark the least-recently-used non-marked entry for eviction. Never mark the
|
||||
// most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
|
||||
// which means there's only one entry not already marked for eviction), so
|
||||
// that an entry persists in the cache even if it is larger than the allocated
|
||||
// cache size.
|
||||
while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
|
||||
if (auto entry_to_evict = MarkOldestEntryForEviction()) {
|
||||
removed_entries->push_back(entry_to_evict);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheExternal::ToSubEntryRef(
|
||||
CompilationCacheEntryRef* entry,
|
||||
CompilationCacheFetchTarget fetch_target) const {
|
||||
return static_cast<TpuEntryRefImpl*>(entry)->ToSubEntryRef(fetch_target);
|
||||
}
|
||||
|
||||
TpuCompilationCacheExternal::TpuEntryRefImpl::TpuEntryRefImpl(
|
||||
TpuCompilationCacheExternal* parent, CompiledSubgraph* entry, int index)
|
||||
: parent_(parent), entry_(entry), index_(index) {
|
||||
if (entry_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (entry_->main_entry == nullptr) {
|
||||
entry_->Ref();
|
||||
} else {
|
||||
// This is a sharding/unsharding entry nested in a main entry. Only refcount
|
||||
// the main entry.
|
||||
entry_->main_entry->Ref();
|
||||
}
|
||||
}
|
||||
|
||||
TpuCompilationCacheExternal::TpuEntryRefImpl::~TpuEntryRefImpl() {
|
||||
if (entry_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (entry_->main_entry == nullptr) {
|
||||
parent_->DiscardEntryRefs({entry_});
|
||||
} else {
|
||||
parent_->DiscardEntryRefs({entry_->main_entry});
|
||||
}
|
||||
}
|
||||
|
||||
TpuCompilationCacheEntry TpuCompilationCacheExternal::TpuEntryRefImpl::get() {
|
||||
if (entry_ == nullptr) {
|
||||
// Create an empty entry if the entry is nullptr. This corresponds to
|
||||
// non-existing sharding/unsharding entries.
|
||||
return TpuCompilationCacheEntry();
|
||||
}
|
||||
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheExternal::TpuEntryRefImpl::ToSubEntryRef(
|
||||
CompilationCacheFetchTarget fetch_target) {
|
||||
CompiledSubgraph* target = nullptr;
|
||||
switch (fetch_target) {
|
||||
case CompilationCacheFetchTarget::MAIN:
|
||||
target = entry_;
|
||||
break;
|
||||
case CompilationCacheFetchTarget::SHARDING:
|
||||
target = entry_->sharding_entry.get();
|
||||
break;
|
||||
case CompilationCacheFetchTarget::UNSHARDING:
|
||||
target = entry_->unsharding_entry.get();
|
||||
break;
|
||||
default:
|
||||
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
|
||||
}
|
||||
|
||||
if (target == nullptr) {
|
||||
// Cache entry does not have an unsharding subentry. Unref and replace
|
||||
// with nullptr.
|
||||
parent_->DiscardEntryRefs({entry_});
|
||||
}
|
||||
// Otherwise, since the refcount is always on the main entry, we don't need
|
||||
// ref/unref.
|
||||
entry_ = target;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheExternal::Lookup(
|
||||
int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
entry->reset();
|
||||
|
||||
profiler::TraceMe proto_lookup_traceme(
|
||||
"TPU compilation cache proto lookup by uid",
|
||||
/*level=*/2);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto iter = entries_by_uid_.find(uid);
|
||||
if (iter == entries_by_uid_.end()) {
|
||||
return errors::NotFound("No subgraph found for uid ", uid);
|
||||
}
|
||||
CompiledSubgraph* cache_entry = iter->second;
|
||||
if (proto_index < 0 ||
|
||||
proto_index >= cache_entry->tpu_program_group->program_size()) {
|
||||
return errors::NotFound("No proto found for core index ", proto_index,
|
||||
" in subgraph with uid ", uid);
|
||||
}
|
||||
*entry = std::unique_ptr<CompilationCacheEntryRef>(
|
||||
new TpuEntryRefImpl(this, cache_entry, proto_index));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheExternal::Lookup(
|
||||
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
entry->reset();
|
||||
|
||||
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
|
||||
/*level=*/2);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto iter = entries_by_proto_key_.find(proto_key);
|
||||
if (iter == entries_by_proto_key_.end()) {
|
||||
return errors::NotFound("No proto found for key ", proto_key);
|
||||
}
|
||||
CompiledSubgraph* cache_entry = iter->second.first;
|
||||
int proto_index = iter->second.second;
|
||||
*entry = std::unique_ptr<CompilationCacheEntryRef>(
|
||||
new TpuEntryRefImpl(this, cache_entry, proto_index));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheExternal::CompileIfKeyAbsentHelper(
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
std::vector<CompiledSubgraph*>* removed_entries,
|
||||
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
|
||||
const std::function<Status(TpuProgramGroup*)>& compile_function) {
|
||||
profiler::TraceMe subgraph_lookup_traceme(
|
||||
"TPU compilation cache subgraph lookup",
|
||||
/*level=*/2);
|
||||
|
||||
// NOTE: In spite of the fact that we use MutexLock, we do not hold the lock
|
||||
// for the lifetime of the object, see InitializeEntry() call below.
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
std::string cache_key = FindCacheKey(subgraph_key);
|
||||
auto iter = cache_store_.find(cache_key);
|
||||
bool is_new_key = iter == cache_store_.end();
|
||||
|
||||
const std::string session_name = SessionNameFromMetadata(session_metadata);
|
||||
|
||||
CompiledSubgraph* entry = nullptr;
|
||||
if (is_new_key) {
|
||||
cache_key = ConstructCompilationCacheKey(subgraph_key);
|
||||
TpuCompilationCacheMetrics::IncrementCacheLookupCount(
|
||||
/*is_cache_hit=*/false, session_name);
|
||||
const string msg =
|
||||
strings::StrCat("TPU host compilation cache miss: cache_key(",
|
||||
cache_key, "), session_name(", session_name, ")");
|
||||
|
||||
TRACESTRING(msg);
|
||||
LOG(INFO) << msg;
|
||||
|
||||
// Check if caller has disabled compilation. Set using
|
||||
// internal::ScopedTpuCompileDisabler.
|
||||
if (!IsTpuCompilationEnabled()) {
|
||||
const string error_msg = strings::StrCat(
|
||||
"[TpuCompilationDisabled]: Compilation cache miss, but compilation "
|
||||
"disabled, session_name(",
|
||||
session_name, ") Debug String: ", subgraph_key.debug_string);
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "Cache Missed. Current cache entries: ";
|
||||
for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) {
|
||||
// TODO(henrytan): add DebugKey as cache_entry_debug_string to
|
||||
// TpuCompilationCacheKey.
|
||||
VLOG(2) << "Cache Debug Info: ";
|
||||
VLOG(2) << it->second->cache_entry_debug_string;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
|
||||
return errors::NotFound(error_msg);
|
||||
}
|
||||
|
||||
// The single ref on the newly-created entry is owned by the caller.
|
||||
VLOG(1) << "Before adding new entry for key " << cache_key
|
||||
<< " with session_name( " << session_name << ");"
|
||||
<< "; cache is " << cache_store_.size() << " entries ("
|
||||
<< cache_size_ + marked_for_eviction_size_ << " bytes), "
|
||||
<< " marked for eviction "
|
||||
<< (cache_store_.size() - entries_by_last_use_.size())
|
||||
<< " entries (" << marked_for_eviction_size_ << " bytes).";
|
||||
// Note that InitializeEntry() will Release/Reacquire mu_.
|
||||
entry = InitializeEntry(cache_key, compile_function, subgraph_key);
|
||||
TRACELITERAL("TPU host compilation cache: compilation done.");
|
||||
|
||||
LOG(INFO) << strings::StrCat(
|
||||
"TPU host compilation cache: compilation done for cache_key(",
|
||||
cache_key, "), session_name(", session_name, ")");
|
||||
// If session_name is present, log some additional stats related to HBM
|
||||
// here, so that they can be associated directly to the session.
|
||||
if (!session_name.empty()) {
|
||||
entry->tpu_program_group->LogProgramMemorySummary();
|
||||
}
|
||||
} else {
|
||||
TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name);
|
||||
const string msg =
|
||||
strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
|
||||
"), session_name(", session_name, ")");
|
||||
TRACESTRING(msg);
|
||||
VLOG(1) << msg;
|
||||
VLOG(1) << "Before refreshing entry for key " << cache_key
|
||||
<< " with session_name( " << session_name << "); cache is "
|
||||
<< cache_store_.size() << " entries ("
|
||||
<< cache_size_ + marked_for_eviction_size_ << " bytes), "
|
||||
<< " marked for eviction "
|
||||
<< (cache_store_.size() - entries_by_last_use_.size())
|
||||
<< " entries (" << marked_for_eviction_size_ << " bytes).";
|
||||
entry = iter->second;
|
||||
// Make a new reference that is owned by the caller.
|
||||
entry->Ref();
|
||||
// Block if necessary until the subgraph has been initialized.
|
||||
mu_.Await(absl::Condition(
|
||||
+[](CompiledSubgraph* e) { return e->initialized; }, entry));
|
||||
}
|
||||
|
||||
// Let the caller know the uid of the entry.
|
||||
*uid = entry->uid;
|
||||
// Let the caller know the keys for each of the cached protos.
|
||||
*proto_key = entry->proto_key;
|
||||
*may_modify_variables = entry->tpu_program_group->may_modify_variables();
|
||||
*hlo_metadata = entry->tpu_program_group->hlo_metadatas();
|
||||
|
||||
// If the caller didn't supply a per_step_ref_holder then the caller is going
|
||||
// to manually release the reference later via a call to Release().
|
||||
if (per_step_ref_holder == nullptr) {
|
||||
++entry->external_references;
|
||||
} else {
|
||||
// The caller wants its reference to be handed off to a per-step holder that
|
||||
// will discard the reference when the step completes.
|
||||
RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder);
|
||||
TF_RET_CHECK(cast_ref_holder != nullptr);
|
||||
cast_ref_holder->AddRef(entry);
|
||||
}
|
||||
|
||||
// Remove the old LRU-table entry if it wasn't already marked for eviction.
|
||||
auto erased = entries_by_last_use_.erase(entry->last_use);
|
||||
// Update the LRU table indicating this entry is the most recently used.
|
||||
entry->last_use = use_counter_++;
|
||||
entries_by_last_use_[entry->last_use] = entry;
|
||||
if (erased == 0) {
|
||||
// The entry had been marked for eviction, or is newly created.
|
||||
LookupEntryMarkedForEviction(entry, removed_entries);
|
||||
}
|
||||
|
||||
// Log a little more verbosely when a key is added.
|
||||
if (VLOG_IS_ON(1) || is_new_key) {
|
||||
LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
|
||||
<< " entry for key " << cache_key << " with session_name "
|
||||
<< session_name << " cache is " << cache_store_.size()
|
||||
<< " entries (" << cache_size_ + marked_for_eviction_size_
|
||||
<< " bytes), "
|
||||
<< " marked for eviction "
|
||||
<< (cache_store_.size() - entries_by_last_use_.size())
|
||||
<< " entries (" << marked_for_eviction_size_ << " bytes).";
|
||||
}
|
||||
return entry->initialization_status;
|
||||
}
|
||||
|
||||
tensorflow::Status TpuCompilationCacheExternal::CompileIfKeyAbsent(
|
||||
const TpuCompilationCacheKey& cache_key,
|
||||
const tensorflow::SessionMetadata* session_metadata,
|
||||
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
|
||||
const std::function<tensorflow::Status(TpuProgramGroup*)>&
|
||||
compile_function) {
|
||||
std::vector<CompiledSubgraph*> removed_entries;
|
||||
auto status = CompileIfKeyAbsentHelper(
|
||||
cache_key, session_metadata, per_step_ref_holder, uid, proto_key,
|
||||
may_modify_variables, &removed_entries, hlo_metadata, compile_function);
|
||||
for (auto entry : removed_entries) {
|
||||
UnloadAndDestroy(entry);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -26,14 +26,11 @@ limitations under the License.
|
|||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
||||
|
@ -43,25 +40,37 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
constexpr char kCompilationCacheResourceName[] = "tpu_compilation_cache";
|
||||
constexpr char kCompilationCacheUnloaderResourceName[] =
|
||||
const char kCompilationCacheResourceName[] = "tpu_compilation_cache";
|
||||
const char kCompilationCacheUnloaderResourceName[] =
|
||||
"tpu_compilation_cache_unloader";
|
||||
|
||||
class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
||||
// Base class that holds references to compiled protos so that the protos are
|
||||
// not garbage-collected before being used by execute ops. Use
|
||||
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
|
||||
// ref holder object.
|
||||
class TpuCompilationRefHolder : public ResourceBase {
|
||||
public:
|
||||
~TpuCompilationRefHolder() override = default;
|
||||
};
|
||||
|
||||
class TpuCompilationCacheExternal : public ResourceBase {
|
||||
public:
|
||||
using Status = ::stream_executor::port::Status;
|
||||
|
||||
class EntryRefImpl
|
||||
: public CompilationCacheEntryRefImpl<TpuCompilationCacheEntry> {
|
||||
public:
|
||||
EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry,
|
||||
int index);
|
||||
explicit TpuCompilationCacheExternal(int64_t max_cache_size);
|
||||
~TpuCompilationCacheExternal() override;
|
||||
TpuCompilationCacheExternal(const TpuCompilationCacheExternal&) = delete;
|
||||
TpuCompilationCacheExternal& operator=(const TpuCompilationCacheExternal&) =
|
||||
delete;
|
||||
|
||||
TpuCompilationCacheEntry get() override;
|
||||
};
|
||||
|
||||
explicit TpuCompilationCacheExternal(int64 max_cache_size)
|
||||
: TpuCompilationCacheInterface(max_cache_size) {}
|
||||
Status CompileIfKeyAbsent(
|
||||
const TpuCompilationCacheKey& cache_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
|
||||
const std::function<tensorflow::Status(TpuProgramGroup*)>&
|
||||
compile_function);
|
||||
|
||||
static TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||
absl::string_view function_name, uint64 function_library_fingerprint,
|
||||
|
@ -73,7 +82,177 @@ class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
|||
|
||||
string DebugString() const override { return "TpuCompilationCacheExternal"; }
|
||||
|
||||
// Makes a reference holder for this cache, that can be stored in the per-step
|
||||
// resource manager and will ensure that compiled entries persist until the
|
||||
// end of a step.
|
||||
TpuCompilationRefHolder* MakePerStepRefHolder();
|
||||
|
||||
// Differences between MarkEntryForEviction and Release:
|
||||
// There are two modes of managing cache entries:
|
||||
// 1) LRU eviction + pinning; 2) manual.
|
||||
// We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
|
||||
// Otherwise it is manual mode (mainly used by XRT).
|
||||
// MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
|
||||
// entries when callers know that they do not need them anymore.
|
||||
// Release should only be used in mode 2) to explicitly remove an entry.
|
||||
|
||||
// Mark the entry indexed by `subgraph_uid` for eviction. This should only be
|
||||
// called if per_step_ref_holder was NOT nullptr in the corresponding call to
|
||||
// CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
|
||||
// subgraph_uid).
|
||||
Status MarkEntryForEviction(int64 subgraph_uid);
|
||||
|
||||
// Manually discards a reference to the compiled subgraph. This should only be
|
||||
// called if per_step_ref_holder was nullptr in the corresponding call to
|
||||
// CompileIfKeyAbsent(subgraph_key, ...).
|
||||
Status Release(int64 subgraph_uid);
|
||||
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
// the subgraph represented by key. On success a pointer to an EntryRef
|
||||
// holding the program is returned in entry.
|
||||
Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry);
|
||||
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
// the subgraph represented by uid. On success a pointer to an EntryRef
|
||||
// holding the program is returned in entry.
|
||||
Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry);
|
||||
|
||||
// Mutates the main entry ref to point to the entry's subentry
|
||||
// (for sharding/unsharding) or main entry (unchanged) representing the
|
||||
// fetch target. The entry ref needs to point to the main entry before this
|
||||
// call.
|
||||
//
|
||||
// If the requested subentry does not exist, the ref will point to a nullptr
|
||||
// entry.
|
||||
Status ToSubEntryRef(CompilationCacheEntryRef* entry,
|
||||
CompilationCacheFetchTarget fetch_target) const;
|
||||
|
||||
private:
|
||||
// Wrapper for a cache entry that holds a reference to the entry until the
|
||||
// wrapper is deleted. This wrapper is the concrete type of
|
||||
// CompilationCacheEntryRef returned by Lookup.
|
||||
class TpuEntryRefImpl : public CompilationCacheEntryRef {
|
||||
public:
|
||||
TpuEntryRefImpl(TpuCompilationCacheExternal* parent,
|
||||
CompiledSubgraph* entry, int index);
|
||||
~TpuEntryRefImpl() override;
|
||||
|
||||
TpuCompilationCacheEntry get() override;
|
||||
|
||||
// Mutates this ref to point to the entry's subentry (for
|
||||
// sharding/unsharding) or main entry (unchanged) as specified by
|
||||
// fetch_target. The refcount is kept unchanged, since we only track the
|
||||
// refcount of the main entry. The entry ref needs to point to the main
|
||||
// entry before this call.
|
||||
//
|
||||
// If the requested subentry does not exist, the ref will point to a nullptr
|
||||
// entry, and the original entry will be unref'ed.
|
||||
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
|
||||
|
||||
private:
|
||||
TpuCompilationCacheExternal* parent_; // Not owned.
|
||||
// A reference to entry_ is acquired in the constructor and released via
|
||||
// parent->DiscardEntryRefs in the destructor.
|
||||
CompiledSubgraph* entry_;
|
||||
// The program in entry_ that is returned by the get method.
|
||||
int index_;
|
||||
};
|
||||
|
||||
// Private implementation of the generic CompilationRefHolder that knows about
|
||||
// CompiledSubgraph entries.
|
||||
class RefHolder : public TpuCompilationRefHolder {
|
||||
public:
|
||||
explicit RefHolder(TpuCompilationCacheExternal* parent) : parent_(parent) {
|
||||
parent_->Ref();
|
||||
}
|
||||
~RefHolder() override {
|
||||
// Release our reference to the parent.
|
||||
parent_->Unref();
|
||||
}
|
||||
|
||||
// Adds entry to the list of entries that will be released when the
|
||||
// RefHolder is destroyed. Each entry is released via a call to
|
||||
// parent_->DiscardEntryRefs.
|
||||
void AddRef(CompiledSubgraph* entry) { entries_.push_back(entry); }
|
||||
|
||||
string DebugString() const override {
|
||||
return "TpuCompilationCacheExternal::RefHolder";
|
||||
}
|
||||
|
||||
private:
|
||||
TpuCompilationCacheExternal* parent_; // Not owned.
|
||||
std::vector<CompiledSubgraph*> entries_;
|
||||
};
|
||||
|
||||
// The bulk of implementation of CompileIfKeyAbsent() with the exception
|
||||
// of unloading programs that corresponds to possibly removed cache
|
||||
// entries. The split helps to manage locking since we prefer to perform
|
||||
// unloading without holding extra locks.
|
||||
Status CompileIfKeyAbsentHelper(
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
std::vector<CompiledSubgraph*>* removed_entries,
|
||||
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
|
||||
const std::function<Status(TpuProgramGroup*)>& compile_function);
|
||||
|
||||
// This is called by the cache when entry is marked for eviction; by
|
||||
// a RefHolder (via DiscardEntryRefs) when a step completes; and by
|
||||
// an EntryRefImpl when it is destroyed. Releases one reference to entry
|
||||
// if more than 1 remains. If only one reference is left, the entry is removed
|
||||
// from cache_ and is returned to the caller; which must eventually call
|
||||
// UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
|
||||
// to avoid holding the lock during program unloading.
|
||||
ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
|
||||
CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
// Convenience method called by ~RefHolder without mu_ held. Calls
|
||||
// DiscardEntryRef on every element of entries.
|
||||
void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
|
||||
|
||||
// Marks the oldest unmarked entry for eviction. Requires that there is at
|
||||
// least one such entry. In case the evicted entry had only 1 reference it
|
||||
// is removed from the cache and returned to the caller which must eventually
|
||||
// call UnloadAndDestroy.
|
||||
CompiledSubgraph* MarkOldestEntryForEviction()
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Updates datastructures to indicate that entry, which had been marked for
|
||||
// eviction, has been looked up. This is called by CompileIfKeyAbsent when an
|
||||
// entry is newly created, or an entry that has been marked for eviction but
|
||||
// not yet evicted is looked up.
|
||||
//
|
||||
// First the entry is unmarked for eviction, i.e. the cache gains a reference
|
||||
// to entry, entry's last_use field is set to be the most recent value of
|
||||
// use_counter_ and entries_by_last_use_ is updated accordingly.
|
||||
//
|
||||
// Next, the size of the cache is examined to see if any other entries need to
|
||||
// be marked for eviction now that entry has been unmarked. While the total
|
||||
// size of unmarked cached entries is greater than max_cache_size_, entries
|
||||
// are marked for eviction in LRU order. The most recently used entry is never
|
||||
// marked for eviction, so an entry larger than the max cache size will remain
|
||||
// in the cache until it is replaced by something else. In case some entries
|
||||
// actually were removed from the cache, they are a returned to the caller via
|
||||
// removed_entries. The caller must eventually delete them by calling
|
||||
// UnloadAndDestroy.
|
||||
void LookupEntryMarkedForEviction(
|
||||
CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Removes the entry with given key from cache.
|
||||
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Inserts the given key and entry to cache.
|
||||
void InsertEntry(const std::string& key,
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Returns the cache key matching given subgraph_key.
|
||||
std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Creates a new entry by running initialize_programs and places it in the
|
||||
// cache to be looked up by key. The new entry is in the 'marked for eviction'
|
||||
// state (not present in entries_by_last_use_) and the caller is expected to
|
||||
|
@ -82,10 +261,61 @@ class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
|||
// **InitializeEntry releases mu_ during the call to initialize_programs.**
|
||||
CompiledSubgraph* InitializeEntry(
|
||||
const string& key,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>&
|
||||
initialize_program,
|
||||
const std::function<Status(TpuProgramGroup*)>& initialize_program,
|
||||
const TpuCompilationCacheKey& subgraph_key)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(TpuCompilationCacheInterface::mu_) override;
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Unloads the program associated with the entry from all local devices
|
||||
// and deletes the entry itself. It is assumed no one else has a reference
|
||||
// to it and all related keys had already been removed from the cache.
|
||||
// The call can perform device IO so no locks should be held while calling it.
|
||||
void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// The maximum size of entries that are stored in the cache before entries are
|
||||
// marked for eviction.
|
||||
const int64 max_cache_size_;
|
||||
|
||||
mutable absl::Mutex mu_;
|
||||
// The total size of entries that are stored and not marked for eviction.
|
||||
int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
|
||||
|
||||
// The total size of entries that are marked for eviction.
|
||||
int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
|
||||
|
||||
// The value to assign to the last_use field of the next entry that is looked
|
||||
// up.
|
||||
int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
|
||||
|
||||
// session_key_map_ and fingerprint_key_map_ are used for looking up the
|
||||
// cache_ key matching a given subgraph key. When doing a lookup, check
|
||||
// session_key_map_ first to avoid unnecessay fingerprint computation.
|
||||
// Map from key prefix + session_handle to a cache_ key.
|
||||
std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
|
||||
|
||||
// Map from key prefix + fingerprint to a cache_ key.
|
||||
std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
|
||||
|
||||
// All the subgraph entries that can be looked up in the cache. An entry is
|
||||
// marked for eviction iff it is present in cache_ and not in
|
||||
// entries_by_last_use_.
|
||||
std::unordered_map<string, CompiledSubgraph*> cache_store_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
|
||||
// All the subgraph entries that can be looked up in the cache, indexed by
|
||||
// uid.
|
||||
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
|
||||
// All the protos that can be looked up in the cache, indexed by proto
|
||||
// key. The value of the map is a subgraph and the index of the proto compiled
|
||||
// for that subgraph.
|
||||
std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
|
||||
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
|
||||
|
||||
// Map from last_use to entry, used to mark entries for eviction in LRU
|
||||
// order. If an entry's last_use counter is not present as a key in
|
||||
// entries_by_last_use_ then the entry has been marked for eviction.
|
||||
std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
|
|
|
@ -1,355 +0,0 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
|
||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
|
||||
#include "tensorflow/core/tpu/kernels/trace_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// Base class that holds references to compiled protos so that the protos are
|
||||
// not garbage-collected before being used by execute ops. Use
|
||||
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
|
||||
// ref holder object.
|
||||
class CompilationRefHolder : public ResourceBase {
|
||||
public:
|
||||
~CompilationRefHolder() override = default;
|
||||
};
|
||||
|
||||
// Base class for a reference to a cached tpu program. A unique_ptr to a
|
||||
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
|
||||
// and ensures the underlying proto is not garbage-collected until the client
|
||||
// discards the ptr.
|
||||
template <typename CacheEntryType>
|
||||
class CompilationCacheEntryRef {
|
||||
public:
|
||||
virtual ~CompilationCacheEntryRef() = default;
|
||||
|
||||
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
|
||||
// of the tpu::CompilationCacheEntryRef.
|
||||
virtual CacheEntryType get() = 0;
|
||||
|
||||
// Mutates this ref to point to the entry's subentry (for
|
||||
// sharding/unsharding) or main entry (unchanged) as specified by
|
||||
// fetch_target. The refcount is kept unchanged, since we only track the
|
||||
// refcount of the main entry. The entry ref needs to point to the main
|
||||
// entry before this call.
|
||||
//
|
||||
// If the requested subentry does not exist, the ref will point to a nullptr
|
||||
// entry, and the original entry will be unref'ed.
|
||||
virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0;
|
||||
};
|
||||
|
||||
class TpuCompilationCacheInterface : public ResourceBase {
|
||||
public:
|
||||
explicit TpuCompilationCacheInterface(int64 max_cache_size);
|
||||
~TpuCompilationCacheInterface() override;
|
||||
|
||||
// Ensures there is an entry for key present in the cache. By the time
|
||||
// CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
|
||||
// for key, and that entry will remain valid at least until
|
||||
// per_step_ref_holder is deleted. The first call to CompileIfKeyAbsent with a
|
||||
// key that is not in the cache will evaluate compile_function to compute the
|
||||
// value to use in the entry. Subsequent calls with the same key will block
|
||||
// until compile_function completes. Other cache reads and inserts may proceed
|
||||
// on other threads while compile_function is executing. If
|
||||
// per_step_ref_holder is nullptr then the caller is responsible for calling
|
||||
// Release(subgraph_key) to manually discard its reference to the compiled
|
||||
// program, once the caller will not look up the compiled program again.
|
||||
//
|
||||
// compile_function should compile the subgraph represented by key and fill in
|
||||
// one TPUExecutableProto per model-parallel core into its passed argument. It
|
||||
// should return OK if and only if compilation succeeds. The executable proto
|
||||
// vector will be discarded on non-OK status.
|
||||
Status CompileIfKeyAbsent(
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
||||
|
||||
// Differences between MarkEntryForEviction and Release:
|
||||
// There are two modes of managing cache entries:
|
||||
// 1) LRU eviction + pinning; 2) manual.
|
||||
// We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
|
||||
// Otherwise it is manual mode (mainly used by XRT).
|
||||
// MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
|
||||
// entries when callers know that they do not need them anymore.
|
||||
// Release should only be used in mode 2) to explicitly remove an entry.
|
||||
|
||||
// Mark the entry indexed by `subgraph_uid` for eviction. This should only be
|
||||
// called if per_step_ref_holder was NOT nullptr in the corresponding call to
|
||||
// CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
|
||||
// subgraph_uid).
|
||||
Status MarkEntryForEviction(int64 subgraph_uid);
|
||||
|
||||
// Manually discards a reference to the compiled subgraph. This should only be
|
||||
// called if per_step_ref_holder was nullptr in the corresponding call to
|
||||
// CompileIfKeyAbsent(subgraph_key, ...).
|
||||
Status Release(int64 subgraph_uid);
|
||||
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
// the subgraph represented by key. On success a pointer to an EntryRef
|
||||
// holding the program is returned in entry.
|
||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||
Status Lookup(const string& proto_key, std::unique_ptr<CacheEntryRef>* entry);
|
||||
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
// the subgraph represented by uid. On success a pointer to an EntryRef
|
||||
// holding the program is returned in entry.
|
||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||
Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CacheEntryRef>* entry);
|
||||
|
||||
// Looks up the subgraph represented by uid, and returns the vector of keys,
|
||||
// one per core, corresponding to that subgraph.
|
||||
Status GetKeysFromUid(int64 uid, std::vector<string>* keys);
|
||||
|
||||
// Makes a reference holder for this cache, that can be stored in the per-step
|
||||
// resource manager and will ensure that compiled entries persist until the
|
||||
// end of a step.
|
||||
CompilationRefHolder* MakePerStepRefHolder();
|
||||
|
||||
// Convenience method called by ~RefHolder without mu_ held. Calls
|
||||
// DiscardEntryRef on every element of entries.
|
||||
void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
|
||||
|
||||
string DebugString() const override { return "TpuCompilationCacheBase"; }
|
||||
|
||||
protected:
|
||||
std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
|
||||
if (!key.has_guaranteed_const) {
|
||||
return key.prefix;
|
||||
}
|
||||
return absl::StrCat(key.prefix, "|", key.session_handle, "|",
|
||||
key.guaranteed_const_fingerprint());
|
||||
}
|
||||
|
||||
// Private implementation of the generic CompilationRefHolder that knows about
|
||||
// CompiledSubgraph entries.
|
||||
class RefHolder : public CompilationRefHolder {
|
||||
public:
|
||||
explicit RefHolder(TpuCompilationCacheInterface* parent);
|
||||
~RefHolder() override;
|
||||
|
||||
// Adds entry to the list of entries that will be released when the
|
||||
// RefHolder is destroyed. Each entry is released via a call to
|
||||
// parent_->DiscardEntryRefs.
|
||||
void AddRef(CompiledSubgraph* entry);
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
private:
|
||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
||||
std::vector<CompiledSubgraph*> entries_;
|
||||
};
|
||||
|
||||
// The bulk of implementation of CompileIfKeyAbsent() with the exception
|
||||
// of unloading programs that corresponds to possibly removed cache
|
||||
// entries. The split helps to manage locking since we prefer to perform
|
||||
// unloading without holding extra locks.
|
||||
Status CompileIfKeyAbsentHelper(
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
std::vector<CompiledSubgraph*>* removed_entries,
|
||||
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
||||
|
||||
// This is called by the cache when entry is marked for eviction; by
|
||||
// a RefHolder (via DiscardEntryRefs) when a step completes; and by
|
||||
// an EntryRefImpl when it is destroyed. Releases one reference to entry
|
||||
// if more than 1 remains. If only one reference is left, the entry is removed
|
||||
// from cache_ and is returned to the caller; which must eventually call
|
||||
// UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
|
||||
// to avoid holding the lock during program unloading.
|
||||
ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
|
||||
CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Marks the oldest unmarked entry for eviction. Requires that there is at
|
||||
// least one such entry. In case the evicted entry had only 1 reference it
|
||||
// is removed from the cache and returned to the caller which must eventually
|
||||
// call UnloadAndDestroy.
|
||||
ABSL_MUST_USE_RESULT CompiledSubgraph* MarkOldestEntryForEviction()
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Updates datastructures to indicate that entry, which had been marked for
|
||||
// eviction, has been looked up. This is called by CompileIfKeyAbsent when an
|
||||
// entry is newly created, or an entry that has been marked for eviction but
|
||||
// not yet evicted is looked up.
|
||||
//
|
||||
// First the entry is unmarked for eviction, i.e. the cache gains a reference
|
||||
// to entry, entry's last_use field is set to be the most recent value of
|
||||
// use_counter_ and entries_by_last_use_ is updated accordingly.
|
||||
//
|
||||
// Next, the size of the cache is examined to see if any other entries need to
|
||||
// be marked for eviction now that entry has been unmarked. While the total
|
||||
// size of unmarked cached entries is greater than max_cache_size_, entries
|
||||
// are marked for eviction in LRU order. The most recently used entry is never
|
||||
// marked for eviction, so an entry larger than the max cache size will remain
|
||||
// in the cache until it is replaced by something else. In case some entries
|
||||
// actually were removed from the cache, they are a returned to the caller via
|
||||
// removed_entries. The caller must eventually delete them by calling
|
||||
// UnloadAndDestroy.
|
||||
void LookupEntryMarkedForEviction(
|
||||
CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Removes the entry with given key from cache.
|
||||
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Inserts the given key and entry to cache.
|
||||
void InsertEntry(const string& key, CompiledSubgraph* entry)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Returns the cache key matching given subgraph_key.
|
||||
string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Creates a new entry by running initialize_programs and places it in the
|
||||
// cache to be looked up by key. The new entry is in the 'marked for eviction'
|
||||
// state (not present in entries_by_last_use_) and the caller is expected to
|
||||
// call LookupEntryMarkedForEviction after InitializeEntry.
|
||||
//
|
||||
// **InitializeEntry releases mu_ during the call to initialize_programs.**
|
||||
virtual CompiledSubgraph* InitializeEntry(
|
||||
const string& key,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>&
|
||||
initialize_programs,
|
||||
const TpuCompilationCacheKey& subgraph_key)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
|
||||
|
||||
// Unloads the program associated with the entry from all local devices
|
||||
// and deletes the entry itself. It is assumed no one else has a reference
|
||||
// to it and all related keys had already been removed from the cache.
|
||||
// The call can perform device IO so no locks should be held while calling it.
|
||||
void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// The maximum size of entries that are stored in the cache before entries are
|
||||
// marked for eviction.
|
||||
const int64 max_cache_size_;
|
||||
// Mutex to protect access to shared resources under multi-threading
|
||||
// environment.
|
||||
absl::Mutex mu_;
|
||||
// The total size of entries that are stored and not marked for eviction.
|
||||
int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
|
||||
// The total size of entries that are marked for eviction.
|
||||
int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
|
||||
// The value to assign to the last_use field of the next entry that is looked
|
||||
// up.
|
||||
int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
|
||||
// session_key_map_ and fingerprint_key_map_ are used for looking up the
|
||||
// cache_ key matching a given subgraph key. When doing a lookup, check
|
||||
// session_key_map_ first to avoid unnecessay fingerprint computation.
|
||||
// Map from key prefix + session_handle to a cache_ key.
|
||||
absl::node_hash_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
|
||||
// Map from key prefix + fingerprint to a cache_ key.
|
||||
absl::node_hash_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
|
||||
// All the subgraph entries that can be looked up in the cache. An entry is
|
||||
// marked for eviction iff it is present in cache_ and not in
|
||||
// entries_by_last_use_.
|
||||
std::unordered_map<string, CompiledSubgraph*> cache_ ABSL_GUARDED_BY(mu_);
|
||||
// All the subgraph entries that can be looked up in the cache, indexed by
|
||||
// uid.
|
||||
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
// All the protos that can be looked up in the cache, indexed by proto
|
||||
// key. The value of the map is a subgraph and the index of the proto compiled
|
||||
// for that subgraph.
|
||||
std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
|
||||
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
|
||||
// Map from last_use to entry, used to mark entries for eviction in LRU
|
||||
// order. If an entry's last_use counter is not present as a key in
|
||||
// entries_by_last_use_ then the entry has been marked for eviction.
|
||||
std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
|
||||
|
||||
TpuCompilationCacheMetrics tpu_compilation_cache_metrics_;
|
||||
|
||||
private:
|
||||
TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
|
||||
TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
|
||||
delete;
|
||||
};
|
||||
|
||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||
Status TpuCompilationCacheInterface::Lookup(
|
||||
int64 uid, int proto_index, std::unique_ptr<CacheEntryRef>* entry) {
|
||||
entry->reset();
|
||||
|
||||
profiler::TraceMe proto_lookup_traceme(
|
||||
"TPU compilation cache proto lookup by uid",
|
||||
/*level=*/2);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto iter = entries_by_uid_.find(uid);
|
||||
if (iter == entries_by_uid_.end()) {
|
||||
return errors::NotFound("No subgraph found for uid ", uid);
|
||||
}
|
||||
CompiledSubgraph* cache_entry = iter->second;
|
||||
if (proto_index < 0 ||
|
||||
proto_index >= cache_entry->tpu_program_group->program_count()) {
|
||||
return errors::NotFound("No proto found for core index ", proto_index,
|
||||
" in subgraph with uid ", uid);
|
||||
}
|
||||
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||
Status TpuCompilationCacheInterface::Lookup(
|
||||
const string& proto_key, std::unique_ptr<CacheEntryRef>* entry) {
|
||||
entry->reset();
|
||||
|
||||
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
|
||||
/*level=*/2);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto iter = entries_by_proto_key_.find(proto_key);
|
||||
if (iter == entries_by_proto_key_.end()) {
|
||||
return errors::NotFound("No proto found for key ", proto_key);
|
||||
}
|
||||
CompiledSubgraph* cache_entry = iter->second.first;
|
||||
int proto_index = iter->second.second;
|
||||
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
|
|
@ -42,7 +42,7 @@ std::string GetName(CompilationCacheFetchTarget target) {
|
|||
} // namespace
|
||||
|
||||
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
|
||||
TpuCompilationCacheInterface* cache)
|
||||
TpuCompilationCacheExternal* cache)
|
||||
: cache_(cache) {}
|
||||
|
||||
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
|
||||
|
@ -50,19 +50,17 @@ TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
|
|||
}
|
||||
|
||||
Status TpuCompilationCacheLocalLookup::Lookup(
|
||||
const string& proto_key,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) {
|
||||
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
|
||||
/*level=*/2);
|
||||
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
|
||||
proto_key, entry);
|
||||
Status s = cache_->Lookup(proto_key, entry);
|
||||
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
|
||||
<< s;
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
s = (*entry)->ToSubEntryRef(fetch_target);
|
||||
s = cache_->ToSubEntryRef(entry->get(), fetch_target);
|
||||
|
||||
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
||||
<< s;
|
||||
|
@ -71,18 +69,17 @@ Status TpuCompilationCacheLocalLookup::Lookup(
|
|||
|
||||
Status TpuCompilationCacheLocalLookup::Lookup(
|
||||
int64 uid, int proto_index,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) {
|
||||
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
|
||||
/*level=*/2);
|
||||
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
|
||||
uid, proto_index, entry);
|
||||
Status s = cache_->Lookup(uid, proto_index, entry);
|
||||
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
|
||||
<< " in local subgraph cache status " << s;
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
s = (*entry)->ToSubEntryRef(fetch_target);
|
||||
s = cache_->ToSubEntryRef(entry->get(), fetch_target);
|
||||
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
||||
<< s;
|
||||
return s;
|
||||
|
|
|
@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||
#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||
#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
@ -30,11 +28,6 @@ namespace tpu {
|
|||
// and when they need to communicate over RPC.
|
||||
class TpuCompilationCacheLookup : public ResourceBase {
|
||||
public:
|
||||
using TpuCompilationCacheEntryRef =
|
||||
::tensorflow::tpu::CompilationCacheEntryRef<TpuCompilationCacheEntry>;
|
||||
using EntryRefImpl =
|
||||
::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl;
|
||||
|
||||
~TpuCompilationCacheLookup() override = default;
|
||||
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
|
@ -49,11 +42,11 @@ class TpuCompilationCacheLookup : public ResourceBase {
|
|||
// fetch_target requests one of them, then after this call
|
||||
// (*entry)->get().get_executable() will return nullptr.
|
||||
virtual Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) = 0;
|
||||
|
||||
virtual Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry) {
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
return Lookup(proto_key, std::move(entry),
|
||||
CompilationCacheFetchTarget::MAIN);
|
||||
}
|
||||
|
@ -63,30 +56,33 @@ class TpuCompilationCacheLookup : public ResourceBase {
|
|||
// returned in program. The wrapper is guaranteed to be valid only during the
|
||||
// execution of the Op requesting the proto.
|
||||
virtual Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) = 0;
|
||||
|
||||
virtual Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry) {
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
return Lookup(uid, proto_index, std::move(entry),
|
||||
CompilationCacheFetchTarget::MAIN);
|
||||
}
|
||||
};
|
||||
|
||||
// Forward declaration to break cycle dependency graph.
|
||||
class TpuCompilationCacheExternal;
|
||||
|
||||
// Class for looking up ISA protos when the execute and compile Op are in the
|
||||
// same address space. The proto is simply looked up in the compilation cache,
|
||||
// without any serialization taking place.
|
||||
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
|
||||
public:
|
||||
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
|
||||
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheExternal* cache);
|
||||
~TpuCompilationCacheLocalLookup() override;
|
||||
|
||||
Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) override;
|
||||
|
||||
Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) override;
|
||||
|
||||
string DebugString() const override;
|
||||
|
@ -94,10 +90,10 @@ class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
|
|||
private:
|
||||
// The subgraph compilation cache, in the same process address space where the
|
||||
// lookups are happening.
|
||||
TpuCompilationCacheInterface* cache_;
|
||||
TpuCompilationCacheExternal* cache_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||
#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||
|
|
|
@ -28,7 +28,6 @@ limitations under the License.
|
|||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
||||
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_util.h"
|
||||
#include "tensorflow/core/tpu/tpu_configuration.h"
|
||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||
|
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include "absl/types/span.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/compile_only_client.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||
|
|
Loading…
Reference in New Issue