TPU library internal change.
PiperOrigin-RevId: 316171774 Change-Id: I85b3c2639c734e391692f568bffae0efe116e9af
This commit is contained in:
parent
9130ba73e6
commit
1e3eddd152
@ -19,9 +19,9 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":tpu_compile_op_support",
|
":tpu_compile_op_support",
|
||||||
":tpu_mesh_state_interface",
|
":tpu_mesh_state_interface",
|
||||||
":tpu_program_group_interface",
|
|
||||||
":tpu_util",
|
":tpu_util",
|
||||||
":tpu_util_hdrs",
|
":tpu_util_hdrs",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
"//tensorflow/compiler/jit:flags",
|
"//tensorflow/compiler/jit:flags",
|
||||||
"//tensorflow/compiler/jit:shape_inference",
|
"//tensorflow/compiler/jit:shape_inference",
|
||||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||||
@ -30,16 +30,16 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/client:client_library",
|
"//tensorflow/compiler/xla/client:client_library",
|
||||||
"//tensorflow/compiler/xla/client:compile_only_client",
|
"//tensorflow/compiler/xla/client:compile_only_client",
|
||||||
|
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
# "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
|
||||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||||
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
|
|
||||||
"//tensorflow/core/tpu:tpu_configuration",
|
"//tensorflow/core/tpu:tpu_configuration",
|
||||||
"//tensorflow/core/tpu:tpu_defs",
|
"//tensorflow/core/tpu:tpu_defs",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
|
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
@ -157,28 +157,14 @@ cc_library(
|
|||||||
"tpu_compilation_cache_entry.h",
|
"tpu_compilation_cache_entry.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":compiled_subgraph",
|
|
||||||
":tpu_compilation_cache_proto_cc",
|
|
||||||
":tpu_executable_info_proto_cc",
|
":tpu_executable_info_proto_cc",
|
||||||
":tpu_program_group",
|
":tpu_program_group",
|
||||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core/lib/core:refcount",
|
"//tensorflow/core/lib/core:refcount",
|
||||||
"//tensorflow/core/platform:casts",
|
"//tensorflow/core/platform:casts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "tpu_compilation_cache_entry_impl",
|
|
||||||
srcs = [],
|
|
||||||
hdrs = ["tpu_compilation_cache_entry_impl.h"],
|
|
||||||
deps = [
|
|
||||||
":compiled_subgraph",
|
|
||||||
":tpu_compilation_cache_interface",
|
|
||||||
":tpu_executable_info_proto_cc",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_compilation_cache_lookup",
|
name = "tpu_compilation_cache_lookup",
|
||||||
srcs = ["tpu_compilation_cache_lookup.cc"],
|
srcs = ["tpu_compilation_cache_lookup.cc"],
|
||||||
@ -188,7 +174,6 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":tpu_compilation_cache_entry",
|
":tpu_compilation_cache_entry",
|
||||||
":tpu_compilation_cache_external",
|
":tpu_compilation_cache_external",
|
||||||
":tpu_compilation_cache_interface",
|
|
||||||
":tpu_compilation_cache_proto_cc",
|
":tpu_compilation_cache_proto_cc",
|
||||||
"//tensorflow/core/lib/core:refcount",
|
"//tensorflow/core/lib/core:refcount",
|
||||||
"//tensorflow/core/platform:status",
|
"//tensorflow/core/platform:status",
|
||||||
@ -262,35 +247,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "tpu_compilation_cache_interface",
|
|
||||||
srcs = ["tpu_compilation_cache_interface.cc"],
|
|
||||||
hdrs = ["tpu_compilation_cache_interface.h"],
|
|
||||||
deps = [
|
|
||||||
":compiled_subgraph",
|
|
||||||
":tpu_compilation_cache_key",
|
|
||||||
":tpu_compilation_cache_metrics_hdrs",
|
|
||||||
":tpu_compilation_cache_proto_cc",
|
|
||||||
":tpu_util",
|
|
||||||
":tpu_util_hdrs",
|
|
||||||
":trace_util_hdrs",
|
|
||||||
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
|
|
||||||
"//tensorflow/compiler/xla:util",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:lib_internal",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
|
||||||
"//tensorflow/core/platform:casts", # buildcleaner: keep
|
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
|
||||||
"@com_google_absl//absl/base:core_headers",
|
|
||||||
"@com_google_absl//absl/container:node_hash_map",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/synchronization",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_compilation_cache_external",
|
name = "tpu_compilation_cache_external",
|
||||||
srcs = ["tpu_compilation_cache_external.cc"],
|
srcs = ["tpu_compilation_cache_external.cc"],
|
||||||
@ -300,8 +256,6 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":compiled_subgraph",
|
":compiled_subgraph",
|
||||||
":tpu_compilation_cache_entry",
|
":tpu_compilation_cache_entry",
|
||||||
":tpu_compilation_cache_entry_impl",
|
|
||||||
":tpu_compilation_cache_interface",
|
|
||||||
":tpu_compilation_cache_key",
|
":tpu_compilation_cache_key",
|
||||||
":tpu_compilation_cache_metrics", # buildcleaner: keep
|
":tpu_compilation_cache_metrics", # buildcleaner: keep
|
||||||
":tpu_compilation_cache_metrics_hdrs",
|
":tpu_compilation_cache_metrics_hdrs",
|
||||||
@ -401,7 +355,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/client:compile_only_client",
|
"//tensorflow/compiler/xla/client:compile_only_client",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -25,9 +25,6 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
|
||||||
// Forward declaration to avoid circular dependency.
|
|
||||||
class TpuCompilationCacheInterface;
|
|
||||||
|
|
||||||
// Cache for compiled TPU program.
|
// Cache for compiled TPU program.
|
||||||
//
|
//
|
||||||
// Each key identifies a unique subgraph, and the value is the vector of
|
// Each key identifies a unique subgraph, and the value is the vector of
|
||||||
@ -103,7 +100,10 @@ class TpuCompilationCacheInterface;
|
|||||||
// unmarked and set to most recently used.
|
// unmarked and set to most recently used.
|
||||||
//
|
//
|
||||||
struct CompiledSubgraph : public core::RefCounted {
|
struct CompiledSubgraph : public core::RefCounted {
|
||||||
TpuCompilationCacheInterface* parent = nullptr; // Not owned.
|
// TODO(henrytan): once `TpuCompilationCache` and
|
||||||
|
// `TpuCompilationCacheExternal` inherits from `TpuCompilationCacheInterface`
|
||||||
|
// update void* with `TpuCompilationCacheInterface`
|
||||||
|
void* parent = nullptr; // Not owned.
|
||||||
|
|
||||||
bool initialized = false;
|
bool initialized = false;
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ struct CompiledSubgraph : public core::RefCounted {
|
|||||||
// owning main entry.
|
// owning main entry.
|
||||||
CompiledSubgraph* main_entry = nullptr;
|
CompiledSubgraph* main_entry = nullptr;
|
||||||
|
|
||||||
// Compiled TPU program group.
|
// Compiled Tpu program.
|
||||||
std::unique_ptr<TpuProgramGroupInterface> tpu_program_group;
|
std::unique_ptr<TpuProgramGroupInterface> tpu_program_group;
|
||||||
|
|
||||||
// Computes total program size.
|
// Computes total program size.
|
||||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
|
||||||
// A version of `CompilationCacheEntry` to access Tpu binary program
|
// A version of `CompilationCacheEntry` that exposes Tpu binary program
|
||||||
// `XLA_TpuProgram`.
|
// `XLA_TpuProgram`.
|
||||||
class TpuCompilationCacheEntry {
|
class TpuCompilationCacheEntry {
|
||||||
public:
|
public:
|
||||||
@ -42,6 +42,28 @@ class TpuCompilationCacheEntry {
|
|||||||
int core_index_;
|
int core_index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Base class for a reference to a cached proto. A unique_ptr to a
|
||||||
|
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
|
||||||
|
// and ensures the underlying proto is not garbage-collected until the client
|
||||||
|
// discards the ptr.
|
||||||
|
class CompilationCacheEntryRef {
|
||||||
|
public:
|
||||||
|
virtual ~CompilationCacheEntryRef() = default;
|
||||||
|
|
||||||
|
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
|
||||||
|
// of the CompilationCacheEntryRef.
|
||||||
|
virtual TpuCompilationCacheEntry get() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Base class that holds references to compiled protos so that the protos are
|
||||||
|
// not garbage-collected before being used by execute ops. Use
|
||||||
|
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
|
||||||
|
// ref holder object.
|
||||||
|
class CompilationRefHolder : public ResourceBase {
|
||||||
|
public:
|
||||||
|
~CompilationRefHolder() override = default;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -1,108 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
|
||||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
|
||||||
|
|
||||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace tpu {
|
|
||||||
|
|
||||||
// Wrapper for a cache entry that holds a reference to the entry until the
|
|
||||||
// wrapper is deleted. This wrapper is the concrete type of
|
|
||||||
// CompilationCacheEntryRef returned by Lookup.
|
|
||||||
template <typename CacheEntryType>
|
|
||||||
class CompilationCacheEntryRefImpl
|
|
||||||
: public CompilationCacheEntryRef<CacheEntryType> {
|
|
||||||
public:
|
|
||||||
CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent,
|
|
||||||
CompiledSubgraph* entry, int index);
|
|
||||||
|
|
||||||
~CompilationCacheEntryRefImpl() override;
|
|
||||||
|
|
||||||
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
|
||||||
// A reference to entry_ is acquired in the constructor and released via
|
|
||||||
// parent->DiscardEntryRefs in the destructor.
|
|
||||||
CompiledSubgraph* entry_;
|
|
||||||
// The index of the program in entry_ that is returned by the get method.
|
|
||||||
int index_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename CacheEntryType>
|
|
||||||
CompilationCacheEntryRefImpl<CacheEntryType>::CompilationCacheEntryRefImpl(
|
|
||||||
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
|
||||||
: parent_(parent), entry_(entry), index_(index) {
|
|
||||||
if (entry_ == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (entry_->main_entry == nullptr) {
|
|
||||||
entry_->Ref();
|
|
||||||
} else {
|
|
||||||
// This is a sharding/unsharding entry nested in a main entry. Only
|
|
||||||
// refcount the main entry.
|
|
||||||
entry_->main_entry->Ref();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename CacheEntryType>
|
|
||||||
CompilationCacheEntryRefImpl<CacheEntryType>::~CompilationCacheEntryRefImpl() {
|
|
||||||
if (entry_ == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (entry_->main_entry == nullptr) {
|
|
||||||
parent_->DiscardEntryRefs({entry_});
|
|
||||||
} else {
|
|
||||||
parent_->DiscardEntryRefs({entry_->main_entry});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename CacheEntryType>
|
|
||||||
Status CompilationCacheEntryRefImpl<CacheEntryType>::ToSubEntryRef(
|
|
||||||
CompilationCacheFetchTarget fetch_target) {
|
|
||||||
CompiledSubgraph* target = nullptr;
|
|
||||||
switch (fetch_target) {
|
|
||||||
case CompilationCacheFetchTarget::MAIN:
|
|
||||||
target = entry_;
|
|
||||||
break;
|
|
||||||
case CompilationCacheFetchTarget::SHARDING:
|
|
||||||
target = entry_->sharding_entry.get();
|
|
||||||
break;
|
|
||||||
case CompilationCacheFetchTarget::UNSHARDING:
|
|
||||||
target = entry_->unsharding_entry.get();
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (target == nullptr) {
|
|
||||||
// Cache entry does not have an unsharding subentry. Unref and replace
|
|
||||||
// with nullptr.
|
|
||||||
parent_->DiscardEntryRefs({entry_});
|
|
||||||
}
|
|
||||||
// Otherwise, since the refcount is always on the main entry, we don't
|
|
||||||
// need ref/unref.
|
|
||||||
entry_ = target;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tpu
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
|
@ -50,6 +50,14 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry,
|
|||||||
entry->initialized = true;
|
entry->initialized = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
|
||||||
|
if (!key.has_guaranteed_const) {
|
||||||
|
return key.prefix;
|
||||||
|
}
|
||||||
|
return absl::StrCat(key.prefix, "|", key.session_handle, "|",
|
||||||
|
key.guaranteed_const_fingerprint());
|
||||||
|
}
|
||||||
|
|
||||||
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
|
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
|
||||||
// data to compute the fingerprint.
|
// data to compute the fingerprint.
|
||||||
std::string GuaranteedConstFingerprint(
|
std::string GuaranteedConstFingerprint(
|
||||||
@ -115,32 +123,85 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl(
|
TpuCompilationCacheExternal::TpuCompilationCacheExternal(int64_t max_cache_size)
|
||||||
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
: max_cache_size_(max_cache_size) {
|
||||||
: CompilationCacheEntryRefImpl<TpuCompilationCacheEntry>(parent, entry,
|
if (max_cache_size < 0) {
|
||||||
index) {}
|
LOG(FATAL) << "`max_cache_size` value must be greater than equal to 0";
|
||||||
|
|
||||||
TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() {
|
|
||||||
if (entry_ == nullptr) {
|
|
||||||
// Create an empty entry if the entry is nullptr. This corresponds to
|
|
||||||
// non-existing sharding/unsharding entries.
|
|
||||||
return TpuCompilationCacheEntry();
|
|
||||||
}
|
}
|
||||||
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
|
VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
|
||||||
|
}
|
||||||
|
|
||||||
|
TpuCompilationCacheExternal::~TpuCompilationCacheExternal() {
|
||||||
|
VLOG(1) << "TpuCompilationCacheExternal::~TpuCompilationCacheExternal()";
|
||||||
|
// A buggy client may be holding onto a reference, or a client might have
|
||||||
|
// crashed while holding onto a reference. In either case, discard all
|
||||||
|
// outstanding client references to avoid leaking storage.
|
||||||
|
for (const auto& entry : entries_by_uid_) {
|
||||||
|
while (entry.second->external_references > 0) {
|
||||||
|
TF_CHECK_OK(Release(entry.first));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (!entries_by_last_use_.empty()) {
|
||||||
|
UnloadAndDestroy(MarkOldestEntryForEviction());
|
||||||
|
}
|
||||||
|
// By the time the cache is deleted all reference holders should have already
|
||||||
|
// been deleted, since they were holding references to the cache. So all
|
||||||
|
// entries should be gone at this point.
|
||||||
|
CHECK_EQ(cache_store_.size(), 0);
|
||||||
|
CHECK_EQ(entries_by_uid_.size(), 0);
|
||||||
|
CHECK_EQ(entries_by_proto_key_.size(), 0);
|
||||||
|
CHECK_EQ(cache_size_, 0);
|
||||||
|
CHECK_EQ(marked_for_eviction_size_, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TpuCompilationCacheExternal::FindCacheKey(
|
||||||
|
const TpuCompilationCacheKey& subgraph_key) const {
|
||||||
|
if (!subgraph_key.has_guaranteed_const) {
|
||||||
|
return subgraph_key.prefix;
|
||||||
|
}
|
||||||
|
auto iter = session_key_map_.find(
|
||||||
|
strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
|
||||||
|
if (iter != session_key_map_.end()) {
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
iter = fingerprint_key_map_.find(strings::StrCat(
|
||||||
|
subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
|
||||||
|
if (iter != session_key_map_.end()) {
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
VLOG(1) << "No matching cache key found for key "
|
||||||
|
<< ConstructCompilationCacheKey(subgraph_key);
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
void TpuCompilationCacheExternal::InsertEntry(
|
||||||
|
const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key,
|
||||||
|
CompiledSubgraph* entry) {
|
||||||
|
entry->parent = this;
|
||||||
|
entry->subgraph_key = cache_key;
|
||||||
|
entry->uid = get_uid();
|
||||||
|
TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
|
||||||
|
entry->cache_entry_debug_string = subgraph_key.prefix;
|
||||||
|
VLOG(1) << "Cache Initializing Entry Session Debug "
|
||||||
|
<< entry->cache_entry_debug_string;
|
||||||
|
|
||||||
|
if (!subgraph_key.has_guaranteed_const) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
session_key_map_.insert(std::make_pair(
|
||||||
|
strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle),
|
||||||
|
cache_key));
|
||||||
|
fingerprint_key_map_.insert(std::make_pair(
|
||||||
|
strings::StrCat(subgraph_key.prefix,
|
||||||
|
subgraph_key.guaranteed_const_fingerprint()),
|
||||||
|
cache_key));
|
||||||
}
|
}
|
||||||
|
|
||||||
CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
||||||
const string& key,
|
const string& key,
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>& initialize_program,
|
const std::function<Status(TpuProgramGroup*)>& initialize_program,
|
||||||
const TpuCompilationCacheKey& subgraph_key) {
|
const TpuCompilationCacheKey& subgraph_key) {
|
||||||
CompiledSubgraph* main_entry = new CompiledSubgraph();
|
CompiledSubgraph* main_entry = new CompiledSubgraph();
|
||||||
main_entry->parent = this;
|
|
||||||
main_entry->subgraph_key = key;
|
|
||||||
main_entry->uid = get_uid();
|
|
||||||
// TODO(henrytan): implement TpuCompilationCacheKey.debug_string.
|
|
||||||
main_entry->cache_entry_debug_string = subgraph_key.prefix;
|
|
||||||
VLOG(1) << "Cache Initializing Entry Session Debug "
|
|
||||||
<< main_entry->cache_entry_debug_string;
|
|
||||||
|
|
||||||
// Add the entry to the cache, with size zero since there are no compiled
|
// Add the entry to the cache, with size zero since there are no compiled
|
||||||
// programs in it. Once the subgraph has been compiled,
|
// programs in it. Once the subgraph has been compiled,
|
||||||
@ -151,7 +212,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
|||||||
// who created the entry. A second reference, owned by the cache, will be
|
// who created the entry. A second reference, owned by the cache, will be
|
||||||
// added below since we leave the entry in the 'marked for eviction' state
|
// added below since we leave the entry in the 'marked for eviction' state
|
||||||
// here.
|
// here.
|
||||||
InsertEntry(key, main_entry);
|
InsertEntry(key, subgraph_key, main_entry);
|
||||||
|
|
||||||
// Initialize the programs outside the lock so that other cache operations
|
// Initialize the programs outside the lock so that other cache operations
|
||||||
// can proceed during the (potentially lengthy) initialization.
|
// can proceed during the (potentially lengthy) initialization.
|
||||||
@ -259,5 +320,470 @@ TpuCompilationCacheExternal::CreateCompilationCacheKey(
|
|||||||
}
|
}
|
||||||
return key;
|
return key;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TpuCompilationRefHolder* TpuCompilationCacheExternal::MakePerStepRefHolder() {
|
||||||
|
return new RefHolder(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marks the entry identified by `subgraph_uid` for eviction and drops the
// cache's reference to it.
//
// Returns OK if the uid is unknown (already evicted) or the entry was marked.
// Returns an Internal error if the entry still has external references; such
// entries must be released through Release() instead.
Status TpuCompilationCacheExternal::MarkEntryForEviction(int64 subgraph_uid) {
  profiler::TraceMe key_release_traceme(
      "TPU compilation cache possibly evict uid",
      /*level=*/2);
  CompiledSubgraph* evicted = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto found = entries_by_uid_.find(subgraph_uid);
    if (found == entries_by_uid_.end()) {
      // Nothing to do: the entry is no longer in the cache.
      return Status::OK();
    }

    CompiledSubgraph* victim = found->second;
    // Entries with outstanding external references may not be evicted here.
    if (victim->external_references != 0) {
      return errors::Internal("Subgraph ", victim->subgraph_key,
                              " external_references greater than zero. Should "
                              "use TpuCompilationCache::Release.");
    }

    VLOG(1) << "Marking " << victim->subgraph_key << " for eviction";
    // Move the entry's size from the live total to the marked-for-eviction
    // total and take it out of the LRU index.
    const auto victim_size = victim->total_size;
    entries_by_last_use_.erase(victim->last_use);
    cache_size_ -= victim_size;
    marked_for_eviction_size_ += victim_size;

    // Non-null only when the cache held the last reference; otherwise the
    // entry lives on until the remaining reference holders let go.
    evicted = DiscardEntryRef(victim);

    VLOG(1) << "After possibly evicting entry " << subgraph_uid
            << " refs cache is " << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }

  // Program unloading happens after the lock scope has ended.
  UnloadAndDestroy(evicted);
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Drops one external reference on the entry identified by `subgraph_uid`.
//
// Returns NotFound if no entry with that uid exists. CHECK-fails if the entry
// has no external references left to release.
Status TpuCompilationCacheExternal::Release(int64 subgraph_uid) {
  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
                                        /*level=*/2);

  CompiledSubgraph* evicted = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);
    if (iter == entries_by_uid_.end()) {
      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
    }

    CHECK_GT(iter->second->external_references, 0);
    CompiledSubgraph* subgraph = iter->second;
    subgraph->external_references -= 1;

    // Non-null only if this was the last reference to the entry.
    evicted = DiscardEntryRef(subgraph);

    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
            << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  // Program unloading happens after the lock scope has ended.
  UnloadAndDestroy(evicted);
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Unloads the programs held by `entry` and releases its final reference.
// A null `entry` is ignored. CHECK-fails unless exactly one reference remains.
void TpuCompilationCacheExternal::UnloadAndDestroy(CompiledSubgraph* entry) {
  if (entry == nullptr) {
    return;
  }

  CHECK(entry->RefCountIsOne());
  entry->tpu_program_group->UnloadAndDestroyPrograms();
  entry->Unref();
}
|
||||||
|
|
||||||
|
size_t TpuCompilationCacheExternal::RemoveEntry(const string& key) {
|
||||||
|
auto erased = cache_store_.erase(key);
|
||||||
|
TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
|
||||||
|
auto parsed_key_or_status = ParseCompilationCacheKey(key);
|
||||||
|
CHECK(parsed_key_or_status.status().ok());
|
||||||
|
const TpuCompilationCacheKey parsed_key =
|
||||||
|
parsed_key_or_status.ConsumeValueOrDie();
|
||||||
|
if (!parsed_key.has_guaranteed_const) {
|
||||||
|
return erased;
|
||||||
|
}
|
||||||
|
session_key_map_.erase(
|
||||||
|
strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
|
||||||
|
fingerprint_key_map_.erase(strings::StrCat(
|
||||||
|
parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
|
||||||
|
return erased;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drops one reference to `entry`.
//
// If other references remain, the entry is simply unreffed and nullptr is
// returned. If this is the last reference, the entry is removed from every
// cache index and returned still holding that reference, so the caller can
// hand it to UnloadAndDestroy().
ABSL_MUST_USE_RESULT CompiledSubgraph*
TpuCompilationCacheExternal::DiscardEntryRef(CompiledSubgraph* entry) {
  if (!entry->RefCountIsOne()) {
    entry->Unref();
    return nullptr;
  }

  // Last reference: remove the entry for good so that a later lookup cannot
  // bring it back.

  // An entry about to be deleted must already be marked for eviction, i.e.
  // absent from the last-use index.
  CHECK(entries_by_last_use_.find(entry->last_use) ==
        entries_by_last_use_.end());
  // Its size no longer counts towards the marked-for-eviction total.
  marked_for_eviction_size_ -= entry->total_size;

  auto erased = RemoveEntry(entry->subgraph_key);
  if (erased == 0) {
    LOG(FATAL) << "Tried to discard nonexistent cache entry";
  }

  erased = entries_by_uid_.erase(entry->uid);
  CHECK_EQ(erased, 1);

  for (const string& key : entry->proto_key) {
    erased = entries_by_proto_key_.erase(key);
    CHECK_EQ(erased, 1);
  }

  // Deletion itself is left to the caller via UnloadAndDestroy().
  return entry;
}
|
||||||
|
|
||||||
|
// Drops one reference on each of `entries` while holding mu_, then unloads
// and destroys any entries whose last reference was dropped.
void TpuCompilationCacheExternal::DiscardEntryRefs(
    gtl::ArraySlice<CompiledSubgraph*> entries) {
  // One slot per input entry; slots are nullptr for entries that are still
  // referenced elsewhere (UnloadAndDestroy ignores nullptr).
  std::vector<CompiledSubgraph*> to_destroy;
  {
    absl::MutexLock lock(&mu_);

    for (CompiledSubgraph* candidate : entries) {
      CompiledSubgraph* discarded = DiscardEntryRef(candidate);
      to_destroy.push_back(discarded);
    }

    VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
            << " entries (" << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  // Program unloading happens after the lock scope has ended.
  for (CompiledSubgraph* discarded : to_destroy) {
    UnloadAndDestroy(discarded);
  }
}
|
||||||
|
|
||||||
|
// Marks the first entry in the last-use index for eviction and drops the
// cache's reference to it.
//
// Returns the entry if that was its last reference (the caller must then
// unload and destroy it), or nullptr if other holders still reference it. In
// the latter case the entry stays in the cache until the last holder lets go.
ABSL_MUST_USE_RESULT CompiledSubgraph*
TpuCompilationCacheExternal::MarkOldestEntryForEviction() {
  CompiledSubgraph* oldest = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << oldest->subgraph_key << " for eviction";

  // Move the entry's size from the live total to the marked-for-eviction
  // total and take it out of the last-use index.
  entries_by_last_use_.erase(oldest->last_use);
  cache_size_ -= oldest->total_size;
  marked_for_eviction_size_ += oldest->total_size;

  return DiscardEntryRef(oldest);
}
|
||||||
|
|
||||||
|
// Un-marks `entry`, which was previously marked for eviction or is newly
// created: the cache takes a reference on it and its size moves from the
// marked-for-eviction total back to the live total.
//
// If the live total then exceeds max_cache_size_, the oldest entries are
// marked for eviction until it fits or only one unmarked entry is left.
// Entries that lost their last reference in the process are appended to
// `removed_entries` for the caller to unload and destroy.
void TpuCompilationCacheExternal::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) {
  entry->Ref();
  marked_for_eviction_size_ -= entry->total_size;
  cache_size_ += entry->total_size;

  // The last unmarked entry is never evicted, so a single entry may stay
  // cached even when it is larger than the configured limit.
  for (;;) {
    if (entries_by_last_use_.size() <= 1 || cache_size_ <= max_cache_size_) {
      break;
    }
    CompiledSubgraph* evicted = MarkOldestEntryForEviction();
    if (evicted != nullptr) {
      removed_entries->push_back(evicted);
    }
  }
}
|
||||||
|
|
||||||
|
// Forwards to TpuEntryRefImpl::ToSubEntryRef on `entry`.
// `entry` is downcast unchecked, so it must actually be a TpuEntryRefImpl.
Status TpuCompilationCacheExternal::ToSubEntryRef(
    CompilationCacheEntryRef* entry,
    CompilationCacheFetchTarget fetch_target) const {
  auto* entry_ref = static_cast<TpuEntryRefImpl*>(entry);
  return entry_ref->ToSubEntryRef(fetch_target);
}
|
||||||
|
|
||||||
|
// Wraps `entry` and takes a reference that keeps it alive. A null `entry` is
// allowed and takes no reference.
TpuCompilationCacheExternal::TpuEntryRefImpl::TpuEntryRefImpl(
    TpuCompilationCacheExternal* parent, CompiledSubgraph* entry, int index)
    : parent_(parent), entry_(entry), index_(index) {
  if (entry_ != nullptr) {
    // Sharding/unsharding entries are nested in a main entry, and only the
    // main entry is refcounted.
    CompiledSubgraph* refcounted =
        (entry_->main_entry != nullptr) ? entry_->main_entry : entry_;
    refcounted->Ref();
  }
}
|
||||||
|
|
||||||
|
// Gives back the reference taken in the constructor by handing the main entry
// (entry_ itself, or entry_->main_entry for a nested entry) to
// parent_->DiscardEntryRefs. Does nothing when entry_ is nullptr.
TpuCompilationCacheExternal::TpuEntryRefImpl::~TpuEntryRefImpl() {
  if (entry_ != nullptr) {
    if (entry_->main_entry != nullptr) {
      parent_->DiscardEntryRefs({entry_->main_entry});
    } else {
      parent_->DiscardEntryRefs({entry_});
    }
  }
}
|
||||||
|
|
||||||
|
// Returns a TpuCompilationCacheEntry for program `index_` of the wrapped
// entry's program group, or a default-constructed (empty) entry when entry_
// is nullptr, which corresponds to a non-existing sharding/unsharding entry.
TpuCompilationCacheEntry TpuCompilationCacheExternal::TpuEntryRefImpl::get() {
  if (entry_ != nullptr) {
    return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
  }
  return TpuCompilationCacheEntry();
}
|
||||||
|
|
||||||
|
// Re-points this ref at the subentry selected by `fetch_target`:
//   MAIN       -> entry_ (unchanged)
//   SHARDING   -> entry_->sharding_entry
//   UNSHARDING -> entry_->unsharding_entry
// Any other value returns an InvalidArgument error and leaves the ref
// untouched.
//
// entry_ is dereferenced without a null check for SHARDING/UNSHARDING, so it
// must be non-null when this is called.
Status TpuCompilationCacheExternal::TpuEntryRefImpl::ToSubEntryRef(
    CompilationCacheFetchTarget fetch_target) {
  CompiledSubgraph* requested = nullptr;
  switch (fetch_target) {
    case CompilationCacheFetchTarget::MAIN:
      requested = entry_;
      break;
    case CompilationCacheFetchTarget::SHARDING:
      requested = entry_->sharding_entry.get();
      break;
    case CompilationCacheFetchTarget::UNSHARDING:
      requested = entry_->unsharding_entry.get();
      break;
    default:
      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
  }

  if (requested == nullptr) {
    // The cache entry has no subentry of the requested kind. Release the
    // reference held through entry_; entry_ becomes nullptr below.
    parent_->DiscardEntryRefs({entry_});
  }
  // When a subentry exists no ref/unref is needed, since the refcount is
  // always kept on the main entry.
  entry_ = requested;
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Looks up program `proto_index` of the subgraph registered under `uid`.
//
// On success `*entry` holds a new TpuEntryRefImpl for that program. Returns
// NotFound when `uid` is unknown or when `proto_index` lies outside
// [0, program_size()) of the subgraph's program group; `*entry` is left empty
// in both cases.
Status TpuCompilationCacheExternal::Lookup(
    int64 uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  // Whatever the caller passed in is dropped here, before mu_ is acquired.
  entry->reset();

  profiler::TraceMe proto_lookup_traceme(
      "TPU compilation cache proto lookup by uid",
      /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto uid_it = entries_by_uid_.find(uid);
  if (uid_it == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }

  CompiledSubgraph* subgraph = uid_it->second;
  const bool index_in_range =
      proto_index >= 0 &&
      proto_index < subgraph->tpu_program_group->program_size();
  if (!index_in_range) {
    return errors::NotFound("No proto found for core index ", proto_index,
                            " in subgraph with uid ", uid);
  }

  entry->reset(new TpuEntryRefImpl(this, subgraph, proto_index));
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Looks up the program registered under `proto_key`.
//
// entries_by_proto_key_ maps the key to a (subgraph, program index) pair. On
// success `*entry` holds a new TpuEntryRefImpl for that pair. Returns NotFound
// and leaves `*entry` empty when the key is unknown.
Status TpuCompilationCacheExternal::Lookup(
    const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry) {
  // Whatever the caller passed in is dropped here, before mu_ is acquired.
  entry->reset();

  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
                                         /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto key_it = entries_by_proto_key_.find(proto_key);
  if (key_it == entries_by_proto_key_.end()) {
    return errors::NotFound("No proto found for key ", proto_key);
  }

  const auto& subgraph_and_index = key_it->second;
  entry->reset(new TpuEntryRefImpl(this, subgraph_and_index.first,
                                   subgraph_and_index.second));
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Core of CompileIfKeyAbsent(): finds the cache entry for `subgraph_key`,
// creating it through InitializeEntry() on a miss, and reports its details.
//
// Outputs:
//   *uid, *proto_key, *may_modify_variables, *hlo_metadata
//       are copied from the entry that was found or created.
//   *removed_entries
//       receives entries appended by LookupEntryMarkedForEviction() while the
//       cache is trimmed.
//
// The new reference on the entry is either counted in
// entry->external_references (when `per_step_ref_holder` is nullptr) or handed
// to `per_step_ref_holder`, which is downcast to RefHolder.
//
// Returns NotFound on a cache miss while IsTpuCompilationEnabled() is false;
// otherwise returns the entry's initialization_status.
//
// mu_ is acquired with a MutexLock for the whole function, but
// InitializeEntry() releases and reacquires it during compilation.
Status TpuCompilationCacheExternal::CompileIfKeyAbsentHelper(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
    std::vector<CompiledSubgraph*>* removed_entries,
    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
    const std::function<Status(TpuProgramGroup*)>& compile_function) {
  profiler::TraceMe subgraph_lookup_traceme(
      "TPU compilation cache subgraph lookup",
      /*level=*/2);

  // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock
  // for the lifetime of the object, see InitializeEntry() call below.
  absl::MutexLock lock(&mu_);

  std::string cache_key = FindCacheKey(subgraph_key);
  auto iter = cache_store_.find(cache_key);
  bool is_new_key = iter == cache_store_.end();

  const std::string session_name = SessionNameFromMetadata(session_metadata);

  CompiledSubgraph* entry = nullptr;
  if (is_new_key) {
    // Cache miss: build the key the new entry will be stored under.
    cache_key = ConstructCompilationCacheKey(subgraph_key);
    TpuCompilationCacheMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/false, session_name);
    const string msg =
        strings::StrCat("TPU host compilation cache miss: cache_key(",
                        cache_key, "), session_name(", session_name, ")");

    TRACESTRING(msg);
    LOG(INFO) << msg;

    // Check if caller has disabled compilation. Set using
    // internal::ScopedTpuCompileDisabler.
    if (!IsTpuCompilationEnabled()) {
      const string error_msg = strings::StrCat(
          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
          "disabled, session_name(",
          session_name, ") Debug String: ", subgraph_key.debug_string);
      if (VLOG_IS_ON(2)) {
        VLOG(2) << "Cache Missed. Current cache entries: ";
        for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) {
          // TODO(henrytan): add DebugKey as cache_entry_debug_string to
          // TpuCompilationCacheKey.
          VLOG(2) << "Cache Debug Info: ";
          VLOG(2) << it->second->cache_entry_debug_string;
        }
      }

      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
      return errors::NotFound(error_msg);
    }

    // The single ref on the newly-created entry is owned by the caller.
    // The message below now has the same shape as the "Before refreshing"
    // message on the cache-hit path (previously it emitted a stray ");;").
    VLOG(1) << "Before adding new entry for key " << cache_key
            << " with session_name( " << session_name << "); cache is "
            << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << " marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
    // Note that InitializeEntry() will Release/Reacquire mu_.
    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
    TRACELITERAL("TPU host compilation cache: compilation done.");

    LOG(INFO) << strings::StrCat(
        "TPU host compilation cache: compilation done for cache_key(",
        cache_key, "), session_name(", session_name, ")");
    // If session_name is present, log some additional stats related to HBM
    // here, so that they can be associated directly to the session.
    if (!session_name.empty()) {
      entry->tpu_program_group->LogProgramMemorySummary();
    }
  } else {
    TpuCompilationCacheMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/true, session_name);
    const string msg =
        strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
                        "), session_name(", session_name, ")");
    TRACESTRING(msg);
    VLOG(1) << msg;
    VLOG(1) << "Before refreshing entry for key " << cache_key
            << " with session_name( " << session_name << "); cache is "
            << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << " marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized.
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;
  // Let the caller know the keys for each of the cached protos.
  *proto_key = entry->proto_key;
  *may_modify_variables = entry->tpu_program_group->may_modify_variables();
  *hlo_metadata = entry->tpu_program_group->hlo_metadatas();

  // If the caller didn't supply a per_step_ref_holder then the caller is going
  // to manually release the reference later via a call to Release().
  if (per_step_ref_holder == nullptr) {
    ++entry->external_references;
  } else {
    // The caller wants its reference to be handed off to a per-step holder that
    // will discard the reference when the step completes.
    RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder);
    TF_RET_CHECK(cast_ref_holder != nullptr);
    cast_ref_holder->AddRef(entry);
  }

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table indicating this entry is the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry, removed_entries);
  }

  // Log a little more verbosely when a key is added.
  if (VLOG_IS_ON(1) || is_new_key) {
    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
              << " entry for key " << cache_key << " with session_name "
              << session_name << " cache is " << cache_store_.size()
              << " entries (" << cache_size_ + marked_for_eviction_size_
              << " bytes), "
              << " marked for eviction "
              << (cache_store_.size() - entries_by_last_use_.size())
              << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  return entry->initialization_status;
}
|
||||||
|
|
||||||
|
// Public entry point: delegates the lookup/compile work to
// CompileIfKeyAbsentHelper() and afterwards calls UnloadAndDestroy() on every
// entry the helper reported as removed. The helper's MutexLock on mu_ is
// scoped to the helper, so the unloading here runs after that lock has been
// released. The helper's status is returned unchanged.
tensorflow::Status TpuCompilationCacheExternal::CompileIfKeyAbsent(
    const TpuCompilationCacheKey& cache_key,
    const tensorflow::SessionMetadata* session_metadata,
    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
    const std::function<tensorflow::Status(TpuProgramGroup*)>&
        compile_function) {
  std::vector<CompiledSubgraph*> evicted_entries;
  auto helper_status = CompileIfKeyAbsentHelper(
      cache_key, session_metadata, per_step_ref_holder, uid, proto_key,
      may_modify_variables, &evicted_entries, hlo_metadata, compile_function);

  for (CompiledSubgraph* evicted : evicted_entries) {
    UnloadAndDestroy(evicted);
  }
  return helper_status;
}
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -26,14 +26,11 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
|
||||||
#include "tensorflow/core/platform/refcount.h"
|
#include "tensorflow/core/platform/refcount.h"
|
||||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
||||||
@ -43,25 +40,37 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
|
||||||
constexpr char kCompilationCacheResourceName[] = "tpu_compilation_cache";
|
const char kCompilationCacheResourceName[] = "tpu_compilation_cache";
|
||||||
constexpr char kCompilationCacheUnloaderResourceName[] =
|
const char kCompilationCacheUnloaderResourceName[] =
|
||||||
"tpu_compilation_cache_unloader";
|
"tpu_compilation_cache_unloader";
|
||||||
|
|
||||||
class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
// Base class that holds references to compiled protos so that the protos are
|
||||||
|
// not garbage-collected before being used by execute ops. Use
|
||||||
|
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
|
||||||
|
// ref holder object.
|
||||||
|
class TpuCompilationRefHolder : public ResourceBase {
|
||||||
|
public:
|
||||||
|
~TpuCompilationRefHolder() override = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TpuCompilationCacheExternal : public ResourceBase {
|
||||||
public:
|
public:
|
||||||
using Status = ::stream_executor::port::Status;
|
using Status = ::stream_executor::port::Status;
|
||||||
|
|
||||||
class EntryRefImpl
|
explicit TpuCompilationCacheExternal(int64_t max_cache_size);
|
||||||
: public CompilationCacheEntryRefImpl<TpuCompilationCacheEntry> {
|
~TpuCompilationCacheExternal() override;
|
||||||
public:
|
TpuCompilationCacheExternal(const TpuCompilationCacheExternal&) = delete;
|
||||||
EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry,
|
TpuCompilationCacheExternal& operator=(const TpuCompilationCacheExternal&) =
|
||||||
int index);
|
delete;
|
||||||
|
|
||||||
TpuCompilationCacheEntry get() override;
|
Status CompileIfKeyAbsent(
|
||||||
};
|
const TpuCompilationCacheKey& cache_key,
|
||||||
|
const SessionMetadata* session_metadata,
|
||||||
explicit TpuCompilationCacheExternal(int64 max_cache_size)
|
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||||
: TpuCompilationCacheInterface(max_cache_size) {}
|
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||||
|
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
|
||||||
|
const std::function<tensorflow::Status(TpuProgramGroup*)>&
|
||||||
|
compile_function);
|
||||||
|
|
||||||
static TpuCompilationCacheKey CreateCompilationCacheKey(
|
static TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||||
absl::string_view function_name, uint64 function_library_fingerprint,
|
absl::string_view function_name, uint64 function_library_fingerprint,
|
||||||
@ -73,7 +82,177 @@ class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
|||||||
|
|
||||||
string DebugString() const override { return "TpuCompilationCacheExternal"; }
|
string DebugString() const override { return "TpuCompilationCacheExternal"; }
|
||||||
|
|
||||||
|
// Makes a reference holder for this cache, that can be stored in the per-step
|
||||||
|
// resource manager and will ensure that compiled entries persist until the
|
||||||
|
// end of a step.
|
||||||
|
TpuCompilationRefHolder* MakePerStepRefHolder();
|
||||||
|
|
||||||
|
// Differences between MarkEntryForEviction and Release:
|
||||||
|
// There are two modes of managing cache entries:
|
||||||
|
// 1) LRU eviction + pinning; 2) manual.
|
||||||
|
// We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
|
||||||
|
// Otherwise it is manual mode (mainly used by XRT).
|
||||||
|
// MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
|
||||||
|
// entries when callers know that they do not need them anymore.
|
||||||
|
// Release should only be used in mode 2) to explicitly remove an entry.
|
||||||
|
|
||||||
|
// Mark the entry indexed by `subgraph_uid` for eviction. This should only be
|
||||||
|
// called if per_step_ref_holder was NOT nullptr in the corresponding call to
|
||||||
|
// CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
|
||||||
|
// subgraph_uid).
|
||||||
|
Status MarkEntryForEviction(int64 subgraph_uid);
|
||||||
|
|
||||||
|
// Manually discards a reference to the compiled subgraph. This should only be
|
||||||
|
// called if per_step_ref_holder was nullptr in the corresponding call to
|
||||||
|
// CompileIfKeyAbsent(subgraph_key, ...).
|
||||||
|
Status Release(int64 subgraph_uid);
|
||||||
|
|
||||||
|
// Looks up an executable corresponding to the model-parallel core index of
|
||||||
|
// the subgraph represented by key. On success a pointer to an EntryRef
|
||||||
|
// holding the program is returned in entry.
|
||||||
|
Status Lookup(const string& proto_key,
|
||||||
|
std::unique_ptr<CompilationCacheEntryRef>* entry);
|
||||||
|
|
||||||
|
// Looks up an executable corresponding to the model-parallel core index of
|
||||||
|
// the subgraph represented by uid. On success a pointer to an EntryRef
|
||||||
|
// holding the program is returned in entry.
|
||||||
|
Status Lookup(int64 uid, int proto_index,
|
||||||
|
std::unique_ptr<CompilationCacheEntryRef>* entry);
|
||||||
|
|
||||||
|
// Mutates the main entry ref to point to the entry's subentry
|
||||||
|
// (for sharding/unsharding) or main entry (unchanged) representing the
|
||||||
|
// fetch target. The entry ref needs to point to the main entry before this
|
||||||
|
// call.
|
||||||
|
//
|
||||||
|
// If the requested subentry does not exist, the ref will point to a nullptr
|
||||||
|
// entry.
|
||||||
|
Status ToSubEntryRef(CompilationCacheEntryRef* entry,
|
||||||
|
CompilationCacheFetchTarget fetch_target) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// Wrapper for a cache entry that holds a reference to the entry until the
|
||||||
|
// wrapper is deleted. This wrapper is the concrete type of
|
||||||
|
// CompilationCacheEntryRef returned by Lookup.
|
||||||
|
class TpuEntryRefImpl : public CompilationCacheEntryRef {
|
||||||
|
public:
|
||||||
|
TpuEntryRefImpl(TpuCompilationCacheExternal* parent,
|
||||||
|
CompiledSubgraph* entry, int index);
|
||||||
|
~TpuEntryRefImpl() override;
|
||||||
|
|
||||||
|
TpuCompilationCacheEntry get() override;
|
||||||
|
|
||||||
|
// Mutates this ref to point to the entry's subentry (for
|
||||||
|
// sharding/unsharding) or main entry (unchanged) as specified by
|
||||||
|
// fetch_target. The refcount is kept unchanged, since we only track the
|
||||||
|
// refcount of the main entry. The entry ref needs to point to the main
|
||||||
|
// entry before this call.
|
||||||
|
//
|
||||||
|
// If the requested subentry does not exist, the ref will point to a nullptr
|
||||||
|
// entry, and the original entry will be unref'ed.
|
||||||
|
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
|
||||||
|
|
||||||
|
private:
|
||||||
|
TpuCompilationCacheExternal* parent_; // Not owned.
|
||||||
|
// A reference to entry_ is acquired in the constructor and released via
|
||||||
|
// parent->DiscardEntryRefs in the destructor.
|
||||||
|
CompiledSubgraph* entry_;
|
||||||
|
// The program in entry_ that is returned by the get method.
|
||||||
|
int index_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Private implementation of the generic CompilationRefHolder that knows about
|
||||||
|
// CompiledSubgraph entries.
|
||||||
|
class RefHolder : public TpuCompilationRefHolder {
|
||||||
|
public:
|
||||||
|
explicit RefHolder(TpuCompilationCacheExternal* parent) : parent_(parent) {
|
||||||
|
parent_->Ref();
|
||||||
|
}
|
||||||
|
~RefHolder() override {
|
||||||
|
// Release our reference to the parent.
|
||||||
|
parent_->Unref();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds entry to the list of entries that will be released when the
|
||||||
|
// RefHolder is destroyed. Each entry is released via a call to
|
||||||
|
// parent_->DiscardEntryRefs.
|
||||||
|
void AddRef(CompiledSubgraph* entry) { entries_.push_back(entry); }
|
||||||
|
|
||||||
|
string DebugString() const override {
|
||||||
|
return "TpuCompilationCacheExternal::RefHolder";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TpuCompilationCacheExternal* parent_; // Not owned.
|
||||||
|
std::vector<CompiledSubgraph*> entries_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// The bulk of implementation of CompileIfKeyAbsent() with the exception
|
||||||
|
// of unloading programs that corresponds to possibly removed cache
|
||||||
|
// entries. The split helps to manage locking since we prefer to perform
|
||||||
|
// unloading without holding extra locks.
|
||||||
|
Status CompileIfKeyAbsentHelper(
|
||||||
|
const TpuCompilationCacheKey& subgraph_key,
|
||||||
|
const SessionMetadata* session_metadata,
|
||||||
|
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||||
|
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||||
|
std::vector<CompiledSubgraph*>* removed_entries,
|
||||||
|
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
|
||||||
|
const std::function<Status(TpuProgramGroup*)>& compile_function);
|
||||||
|
|
||||||
|
// This is called by the cache when entry is marked for eviction; by
|
||||||
|
// a RefHolder (via DiscardEntryRefs) when a step completes; and by
|
||||||
|
// an EntryRefImpl when it is destroyed. Releases one reference to entry
|
||||||
|
// if more than 1 remains. If only one reference is left, the entry is removed
|
||||||
|
// from cache_ and is returned to the caller; which must eventually call
|
||||||
|
// UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
|
||||||
|
// to avoid holding the lock during program unloading.
|
||||||
|
ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
|
||||||
|
CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
// Convenience method called by ~RefHolder without mu_ held. Calls
|
||||||
|
// DiscardEntryRef on every element of entries.
|
||||||
|
void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
|
||||||
|
|
||||||
|
// Marks the oldest unmarked entry for eviction. Requires that there is at
|
||||||
|
// least one such entry. In case the evicted entry had only 1 reference it
|
||||||
|
// is removed from the cache and returned to the caller which must eventually
|
||||||
|
// call UnloadAndDestroy.
|
||||||
|
CompiledSubgraph* MarkOldestEntryForEviction()
|
||||||
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Updates datastructures to indicate that entry, which had been marked for
|
||||||
|
// eviction, has been looked up. This is called by CompileIfKeyAbsent when an
|
||||||
|
// entry is newly created, or an entry that has been marked for eviction but
|
||||||
|
// not yet evicted is looked up.
|
||||||
|
//
|
||||||
|
// First the entry is unmarked for eviction, i.e. the cache gains a reference
|
||||||
|
// to entry, entry's last_use field is set to be the most recent value of
|
||||||
|
// use_counter_ and entries_by_last_use_ is updated accordingly.
|
||||||
|
//
|
||||||
|
// Next, the size of the cache is examined to see if any other entries need to
|
||||||
|
// be marked for eviction now that entry has been unmarked. While the total
|
||||||
|
// size of unmarked cached entries is greater than max_cache_size_, entries
|
||||||
|
// are marked for eviction in LRU order. The most recently used entry is never
|
||||||
|
// marked for eviction, so an entry larger than the max cache size will remain
|
||||||
|
// in the cache until it is replaced by something else. In case some entries
|
||||||
|
// actually were removed from the cache, they are returned to the caller via
|
||||||
|
// removed_entries. The caller must eventually delete them by calling
|
||||||
|
// UnloadAndDestroy.
|
||||||
|
void LookupEntryMarkedForEviction(
|
||||||
|
CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
|
||||||
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Removes the entry with given key from cache.
|
||||||
|
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Inserts the given key and entry to cache.
|
||||||
|
void InsertEntry(const std::string& key,
|
||||||
|
const TpuCompilationCacheKey& subgraph_key,
|
||||||
|
CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Returns the cache key matching given subgraph_key.
|
||||||
|
std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const
|
||||||
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
// Creates a new entry by running initialize_programs and places it in the
|
// Creates a new entry by running initialize_programs and places it in the
|
||||||
// cache to be looked up by key. The new entry is in the 'marked for eviction'
|
// cache to be looked up by key. The new entry is in the 'marked for eviction'
|
||||||
// state (not present in entries_by_last_use_) and the caller is expected to
|
// state (not present in entries_by_last_use_) and the caller is expected to
|
||||||
@ -82,10 +261,61 @@ class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
|||||||
// **InitializeEntry releases mu_ during the call to initialize_programs.**
|
// **InitializeEntry releases mu_ during the call to initialize_programs.**
|
||||||
CompiledSubgraph* InitializeEntry(
|
CompiledSubgraph* InitializeEntry(
|
||||||
const string& key,
|
const string& key,
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>&
|
const std::function<Status(TpuProgramGroup*)>& initialize_program,
|
||||||
initialize_program,
|
|
||||||
const TpuCompilationCacheKey& subgraph_key)
|
const TpuCompilationCacheKey& subgraph_key)
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(TpuCompilationCacheInterface::mu_) override;
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Unloads the program associated with the entry from all local devices
|
||||||
|
// and deletes the entry itself. It is assumed no one else has a reference
|
||||||
|
// to it and all related keys had already been removed from the cache.
|
||||||
|
// The call can perform device IO so no locks should be held while calling it.
|
||||||
|
void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
|
// The maximum size of entries that are stored in the cache before entries are
|
||||||
|
// marked for eviction.
|
||||||
|
const int64 max_cache_size_;
|
||||||
|
|
||||||
|
mutable absl::Mutex mu_;
|
||||||
|
// The total size of entries that are stored and not marked for eviction.
|
||||||
|
int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
|
||||||
|
|
||||||
|
// The total size of entries that are marked for eviction.
|
||||||
|
int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
|
||||||
|
|
||||||
|
// The value to assign to the last_use field of the next entry that is looked
|
||||||
|
// up.
|
||||||
|
int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
|
||||||
|
|
||||||
|
// session_key_map_ and fingerprint_key_map_ are used for looking up the
|
||||||
|
// cache_ key matching a given subgraph key. When doing a lookup, check
|
||||||
|
// session_key_map_ first to avoid unnecessary fingerprint computation.
|
||||||
|
// Map from key prefix + session_handle to a cache_ key.
|
||||||
|
std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// Map from key prefix + fingerprint to a cache_ key.
|
||||||
|
std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// All the subgraph entries that can be looked up in the cache. An entry is
|
||||||
|
// marked for eviction iff it is present in cache_ and not in
|
||||||
|
// entries_by_last_use_.
|
||||||
|
std::unordered_map<string, CompiledSubgraph*> cache_store_
|
||||||
|
ABSL_GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// All the subgraph entries that can be looked up in the cache, indexed by
|
||||||
|
// uid.
|
||||||
|
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
|
||||||
|
ABSL_GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// All the protos that can be looked up in the cache, indexed by proto
|
||||||
|
// key. The value of the map is a subgraph and the index of the proto compiled
|
||||||
|
// for that subgraph.
|
||||||
|
std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
|
||||||
|
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// Map from last_use to entry, used to mark entries for eviction in LRU
|
||||||
|
// order. If an entry's last_use counter is not present as a key in
|
||||||
|
// entries_by_last_use_ then the entry has been marked for eviction.
|
||||||
|
std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
|
@ -1,355 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
|
|
||||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/base/thread_annotations.h"
|
|
||||||
#include "absl/container/node_hash_map.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/synchronization/mutex.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
|
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
|
||||||
#include "tensorflow/core/lib/core/refcount.h"
|
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
|
||||||
#include "tensorflow/core/protobuf/config.pb.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/trace_util.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace tpu {
|
|
||||||
|
|
||||||
// Base class that holds references to compiled protos so that the protos are
|
|
||||||
// not garbage-collected before being used by execute ops. Use
|
|
||||||
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
|
|
||||||
// ref holder object.
|
|
||||||
class CompilationRefHolder : public ResourceBase {
|
|
||||||
public:
|
|
||||||
~CompilationRefHolder() override = default;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Base class for a reference to a cached tpu program. A unique_ptr to a
|
|
||||||
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
|
|
||||||
// and ensures the underlying proto is not garbage-collected until the client
|
|
||||||
// discards the ptr.
|
|
||||||
template <typename CacheEntryType>
|
|
||||||
class CompilationCacheEntryRef {
|
|
||||||
public:
|
|
||||||
virtual ~CompilationCacheEntryRef() = default;
|
|
||||||
|
|
||||||
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
|
|
||||||
// of the tpu::CompilationCacheEntryRef.
|
|
||||||
virtual CacheEntryType get() = 0;
|
|
||||||
|
|
||||||
// Mutates this ref to point to the entry's subentry (for
|
|
||||||
// sharding/unsharding) or main entry (unchanged) as specified by
|
|
||||||
// fetch_target. The refcount is kept unchanged, since we only track the
|
|
||||||
// refcount of the main entry. The entry ref needs to point to the main
|
|
||||||
// entry before this call.
|
|
||||||
//
|
|
||||||
// If the requested subentry does not exist, the ref will point to a nullptr
|
|
||||||
// entry, and the original entry will be unref'ed.
|
|
||||||
virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class TpuCompilationCacheInterface : public ResourceBase {
|
|
||||||
public:
|
|
||||||
explicit TpuCompilationCacheInterface(int64 max_cache_size);
|
|
||||||
~TpuCompilationCacheInterface() override;
|
|
||||||
|
|
||||||
// Ensures there is an entry for key present in the cache. By the time
|
|
||||||
// CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
|
|
||||||
// for key, and that entry will remain valid at least until
|
|
||||||
// per_step_ref_holder is deleted. The first call to CompileIfKeyAbsent with a
|
|
||||||
// key that is not in the cache will evaluate compile_function to compute the
|
|
||||||
// value to use in the entry. Subsequent calls with the same key will block
|
|
||||||
// until compile_function completes. Other cache reads and inserts may proceed
|
|
||||||
// on other threads while compile_function is executing. If
|
|
||||||
// per_step_ref_holder is nullptr then the caller is responsible for calling
|
|
||||||
// Release(subgraph_key) to manually discard its reference to the compiled
|
|
||||||
// program, once the caller will not look up the compiled program again.
|
|
||||||
//
|
|
||||||
// compile_function should compile the subgraph represented by key and fill in
|
|
||||||
// one TPUExecutableProto per model-parallel core into its passed argument. It
|
|
||||||
// should return OK if and only if compilation succeeds. The executable proto
|
|
||||||
// vector will be discarded on non-OK status.
|
|
||||||
Status CompileIfKeyAbsent(
|
|
||||||
const TpuCompilationCacheKey& subgraph_key,
|
|
||||||
const SessionMetadata* session_metadata,
|
|
||||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
|
||||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
|
||||||
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas,
|
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
|
||||||
|
|
||||||
// Differences between MarkEntryForEviction and Release:
|
|
||||||
// There are two modes of managing cache entries:
|
|
||||||
// 1) LRU eviction + pinning; 2) manual.
|
|
||||||
// We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
|
|
||||||
// Otherwise it is manual mode (mainly used by XRT).
|
|
||||||
// MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
|
|
||||||
// entries when callers know that they do not need them anymore.
|
|
||||||
// Release should only be used in mode 2) to explicitly remove an entry.
|
|
||||||
|
|
||||||
// Mark the entry indexed by `subgraph_uid` for eviction. This should only be
|
|
||||||
// called if per_step_ref_holder was NOT nullptr in the corresponding call to
|
|
||||||
// CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
|
|
||||||
// subgraph_uid).
|
|
||||||
Status MarkEntryForEviction(int64 subgraph_uid);
|
|
||||||
|
|
||||||
// Manually discards a reference to the compiled subgraph. This should only be
|
|
||||||
// called if per_step_ref_holder was nullptr in the corresponding call to
|
|
||||||
// CompileIfKeyAbsent(subgraph_key, ...).
|
|
||||||
Status Release(int64 subgraph_uid);
|
|
||||||
|
|
||||||
// Looks up an executable corresponding to the model-parallel core index of
|
|
||||||
// the subgraph represented by key. On success a pointer to an EntryRef
|
|
||||||
// holding the program is returned in entry.
|
|
||||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
|
||||||
Status Lookup(const string& proto_key, std::unique_ptr<CacheEntryRef>* entry);
|
|
||||||
|
|
||||||
// Looks up an executable corresponding to the model-parallel core index of
|
|
||||||
// the subgraph represented by uid. On success a pointer to an EntryRef
|
|
||||||
// holding the program is returned in entry.
|
|
||||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
|
||||||
Status Lookup(int64 uid, int proto_index,
|
|
||||||
std::unique_ptr<CacheEntryRef>* entry);
|
|
||||||
|
|
||||||
// Looks up the subgraph represented by uid, and returns the vector of keys,
|
|
||||||
// one per core, corresponding to that subgraph.
|
|
||||||
Status GetKeysFromUid(int64 uid, std::vector<string>* keys);
|
|
||||||
|
|
||||||
// Makes a reference holder for this cache, that can be stored in the per-step
|
|
||||||
// resource manager and will ensure that compiled entries persist until the
|
|
||||||
// end of a step.
|
|
||||||
CompilationRefHolder* MakePerStepRefHolder();
|
|
||||||
|
|
||||||
// Convenience method called by ~RefHolder without mu_ held. Calls
|
|
||||||
// DiscardEntryRef on every element of entries.
|
|
||||||
void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);
|
|
||||||
|
|
||||||
string DebugString() const override { return "TpuCompilationCacheBase"; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
|
|
||||||
if (!key.has_guaranteed_const) {
|
|
||||||
return key.prefix;
|
|
||||||
}
|
|
||||||
return absl::StrCat(key.prefix, "|", key.session_handle, "|",
|
|
||||||
key.guaranteed_const_fingerprint());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Private implementation of the generic CompilationRefHolder that knows about
|
|
||||||
// CompiledSubgraph entries.
|
|
||||||
class RefHolder : public CompilationRefHolder {
|
|
||||||
public:
|
|
||||||
explicit RefHolder(TpuCompilationCacheInterface* parent);
|
|
||||||
~RefHolder() override;
|
|
||||||
|
|
||||||
// Adds entry to the list of entries that will be released when the
|
|
||||||
// RefHolder is destroyed. Each entry is released via a call to
|
|
||||||
// parent_->DiscardEntryRefs.
|
|
||||||
void AddRef(CompiledSubgraph* entry);
|
|
||||||
|
|
||||||
string DebugString() const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
|
||||||
std::vector<CompiledSubgraph*> entries_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// The bulk of implementation of CompileIfKeyAbsent() with the exception
|
|
||||||
// of unloading programs that correspond to possibly removed cache
|
|
||||||
// entries. The split helps to manage locking since we prefer to perform
|
|
||||||
// unloading without holding extra locks.
|
|
||||||
Status CompileIfKeyAbsentHelper(
|
|
||||||
const TpuCompilationCacheKey& subgraph_key,
|
|
||||||
const SessionMetadata* session_metadata,
|
|
||||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
|
||||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
|
||||||
std::vector<CompiledSubgraph*>* removed_entries,
|
|
||||||
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadatas,
|
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
|
||||||
|
|
||||||
// This is called by the cache when entry is marked for eviction; by
|
|
||||||
// a RefHolder (via DiscardEntryRefs) when a step completes; and by
|
|
||||||
// an EntryRefImpl when it is destroyed. Releases one reference to entry
|
|
||||||
// if more than 1 remains. If only one reference is left, the entry is removed
|
|
||||||
// from cache_ and is returned to the caller; which must eventually call
|
|
||||||
// UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
|
|
||||||
// to avoid holding the lock during program unloading.
|
|
||||||
ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
|
|
||||||
CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Marks the oldest unmarked entry for eviction. Requires that there is at
|
|
||||||
// least one such entry. In case the evicted entry had only 1 reference it
|
|
||||||
// is removed from the cache and returned to the caller which must eventually
|
|
||||||
// call UnloadAndDestroy.
|
|
||||||
ABSL_MUST_USE_RESULT CompiledSubgraph* MarkOldestEntryForEviction()
|
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Updates data structures to indicate that entry, which had been marked for
|
|
||||||
// eviction, has been looked up. This is called by CompileIfKeyAbsent when an
|
|
||||||
// entry is newly created, or an entry that has been marked for eviction but
|
|
||||||
// not yet evicted is looked up.
|
|
||||||
//
|
|
||||||
// First the entry is unmarked for eviction, i.e. the cache gains a reference
|
|
||||||
// to entry, entry's last_use field is set to be the most recent value of
|
|
||||||
// use_counter_ and entries_by_last_use_ is updated accordingly.
|
|
||||||
//
|
|
||||||
// Next, the size of the cache is examined to see if any other entries need to
|
|
||||||
// be marked for eviction now that entry has been unmarked. While the total
|
|
||||||
// size of unmarked cached entries is greater than max_cache_size_, entries
|
|
||||||
// are marked for eviction in LRU order. The most recently used entry is never
|
|
||||||
// marked for eviction, so an entry larger than the max cache size will remain
|
|
||||||
// in the cache until it is replaced by something else. In case some entries
|
|
||||||
// actually were removed from the cache, they are returned to the caller via
|
|
||||||
// removed_entries. The caller must eventually delete them by calling
|
|
||||||
// UnloadAndDestroy.
|
|
||||||
void LookupEntryMarkedForEviction(
|
|
||||||
CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
|
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Removes the entry with given key from cache.
|
|
||||||
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Inserts the given key and entry to cache.
|
|
||||||
void InsertEntry(const string& key, CompiledSubgraph* entry)
|
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Returns the cache key matching given subgraph_key.
|
|
||||||
string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
|
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Creates a new entry by running initialize_programs and places it in the
|
|
||||||
// cache to be looked up by key. The new entry is in the 'marked for eviction'
|
|
||||||
// state (not present in entries_by_last_use_) and the caller is expected to
|
|
||||||
// call LookupEntryMarkedForEviction after InitializeEntry.
|
|
||||||
//
|
|
||||||
// **InitializeEntry releases mu_ during the call to initialize_programs.**
|
|
||||||
virtual CompiledSubgraph* InitializeEntry(
|
|
||||||
const string& key,
|
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>&
|
|
||||||
initialize_programs,
|
|
||||||
const TpuCompilationCacheKey& subgraph_key)
|
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
|
|
||||||
|
|
||||||
// Unloads the program associated with the entry from all local devices
|
|
||||||
// and deletes the entry itself. It is assumed no one else has a reference
|
|
||||||
// to it and all related keys had already been removed from the cache.
|
|
||||||
// The call can perform device IO so no locks should be held while calling it.
|
|
||||||
void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);
|
|
||||||
|
|
||||||
// The maximum size of entries that are stored in the cache before entries are
|
|
||||||
// marked for eviction.
|
|
||||||
const int64 max_cache_size_;
|
|
||||||
// Mutex to protect access to shared resources under multi-threading
|
|
||||||
// environment.
|
|
||||||
absl::Mutex mu_;
|
|
||||||
// The total size of entries that are stored and not marked for eviction.
|
|
||||||
int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
|
|
||||||
// The total size of entries that are marked for eviction.
|
|
||||||
int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
|
|
||||||
// The value to assign to the last_use field of the next entry that is looked
|
|
||||||
// up.
|
|
||||||
int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
|
|
||||||
// session_key_map_ and fingerprint_key_map_ are used for looking up the
|
|
||||||
// cache_ key matching a given subgraph key. When doing a lookup, check
|
|
||||||
// session_key_map_ first to avoid unnecessary fingerprint computation.
|
|
||||||
// Map from key prefix + session_handle to a cache_ key.
|
|
||||||
absl::node_hash_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
|
|
||||||
// Map from key prefix + fingerprint to a cache_ key.
|
|
||||||
absl::node_hash_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
|
|
||||||
// All the subgraph entries that can be looked up in the cache. An entry is
|
|
||||||
// marked for eviction iff it is present in cache_ and not in
|
|
||||||
// entries_by_last_use_.
|
|
||||||
std::unordered_map<string, CompiledSubgraph*> cache_ ABSL_GUARDED_BY(mu_);
|
|
||||||
// All the subgraph entries that can be looked up in the cache, indexed by
|
|
||||||
// uid.
|
|
||||||
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
|
|
||||||
ABSL_GUARDED_BY(mu_);
|
|
||||||
// All the protos that can be looked up in the cache, indexed by proto
|
|
||||||
// key. The value of the map is a subgraph and the index of the proto compiled
|
|
||||||
// for that subgraph.
|
|
||||||
std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
|
|
||||||
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
|
|
||||||
// Map from last_use to entry, used to mark entries for eviction in LRU
|
|
||||||
// order. If an entry's last_use counter is not present as a key in
|
|
||||||
// entries_by_last_use_ then the entry has been marked for eviction.
|
|
||||||
std::map<int64, CompiledSubgraph*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
|
|
||||||
|
|
||||||
TpuCompilationCacheMetrics tpu_compilation_cache_metrics_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
|
|
||||||
TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
|
|
||||||
delete;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
|
||||||
Status TpuCompilationCacheInterface::Lookup(
|
|
||||||
int64 uid, int proto_index, std::unique_ptr<CacheEntryRef>* entry) {
|
|
||||||
entry->reset();
|
|
||||||
|
|
||||||
profiler::TraceMe proto_lookup_traceme(
|
|
||||||
"TPU compilation cache proto lookup by uid",
|
|
||||||
/*level=*/2);
|
|
||||||
|
|
||||||
absl::MutexLock lock(&mu_);
|
|
||||||
const auto iter = entries_by_uid_.find(uid);
|
|
||||||
if (iter == entries_by_uid_.end()) {
|
|
||||||
return errors::NotFound("No subgraph found for uid ", uid);
|
|
||||||
}
|
|
||||||
CompiledSubgraph* cache_entry = iter->second;
|
|
||||||
if (proto_index < 0 ||
|
|
||||||
proto_index >= cache_entry->tpu_program_group->program_count()) {
|
|
||||||
return errors::NotFound("No proto found for core index ", proto_index,
|
|
||||||
" in subgraph with uid ", uid);
|
|
||||||
}
|
|
||||||
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
|
||||||
Status TpuCompilationCacheInterface::Lookup(
|
|
||||||
const string& proto_key, std::unique_ptr<CacheEntryRef>* entry) {
|
|
||||||
entry->reset();
|
|
||||||
|
|
||||||
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
|
|
||||||
/*level=*/2);
|
|
||||||
|
|
||||||
absl::MutexLock lock(&mu_);
|
|
||||||
const auto iter = entries_by_proto_key_.find(proto_key);
|
|
||||||
if (iter == entries_by_proto_key_.end()) {
|
|
||||||
return errors::NotFound("No proto found for key ", proto_key);
|
|
||||||
}
|
|
||||||
CompiledSubgraph* cache_entry = iter->second.first;
|
|
||||||
int proto_index = iter->second.second;
|
|
||||||
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tpu
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
|
|
@ -42,7 +42,7 @@ std::string GetName(CompilationCacheFetchTarget target) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
|
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
|
||||||
TpuCompilationCacheInterface* cache)
|
TpuCompilationCacheExternal* cache)
|
||||||
: cache_(cache) {}
|
: cache_(cache) {}
|
||||||
|
|
||||||
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
|
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
|
||||||
@ -50,19 +50,17 @@ TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status TpuCompilationCacheLocalLookup::Lookup(
|
Status TpuCompilationCacheLocalLookup::Lookup(
|
||||||
const string& proto_key,
|
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
|
||||||
CompilationCacheFetchTarget fetch_target) {
|
CompilationCacheFetchTarget fetch_target) {
|
||||||
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
|
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
|
||||||
/*level=*/2);
|
/*level=*/2);
|
||||||
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
|
Status s = cache_->Lookup(proto_key, entry);
|
||||||
proto_key, entry);
|
|
||||||
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
|
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
|
||||||
<< s;
|
<< s;
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
s = (*entry)->ToSubEntryRef(fetch_target);
|
s = cache_->ToSubEntryRef(entry->get(), fetch_target);
|
||||||
|
|
||||||
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
||||||
<< s;
|
<< s;
|
||||||
@ -71,18 +69,17 @@ Status TpuCompilationCacheLocalLookup::Lookup(
|
|||||||
|
|
||||||
Status TpuCompilationCacheLocalLookup::Lookup(
|
Status TpuCompilationCacheLocalLookup::Lookup(
|
||||||
int64 uid, int proto_index,
|
int64 uid, int proto_index,
|
||||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) {
|
CompilationCacheFetchTarget fetch_target) {
|
||||||
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
|
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
|
||||||
/*level=*/2);
|
/*level=*/2);
|
||||||
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
|
Status s = cache_->Lookup(uid, proto_index, entry);
|
||||||
uid, proto_index, entry);
|
|
||||||
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
|
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
|
||||||
<< " in local subgraph cache status " << s;
|
<< " in local subgraph cache status " << s;
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
s = (*entry)->ToSubEntryRef(fetch_target);
|
s = cache_->ToSubEntryRef(entry->get(), fetch_target);
|
||||||
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
||||||
<< s;
|
<< s;
|
||||||
return s;
|
return s;
|
||||||
|
@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
|
#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
|
#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/lib/core/refcount.h"
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
@ -30,11 +28,6 @@ namespace tpu {
|
|||||||
// and when they need to communicate over RPC.
|
// and when they need to communicate over RPC.
|
||||||
class TpuCompilationCacheLookup : public ResourceBase {
|
class TpuCompilationCacheLookup : public ResourceBase {
|
||||||
public:
|
public:
|
||||||
using TpuCompilationCacheEntryRef =
|
|
||||||
::tensorflow::tpu::CompilationCacheEntryRef<TpuCompilationCacheEntry>;
|
|
||||||
using EntryRefImpl =
|
|
||||||
::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl;
|
|
||||||
|
|
||||||
~TpuCompilationCacheLookup() override = default;
|
~TpuCompilationCacheLookup() override = default;
|
||||||
|
|
||||||
// Looks up an executable corresponding to the model-parallel core index of
|
// Looks up an executable corresponding to the model-parallel core index of
|
||||||
@ -49,11 +42,11 @@ class TpuCompilationCacheLookup : public ResourceBase {
|
|||||||
// fetch_target requests one of them, then after this call
|
// fetch_target requests one of them, then after this call
|
||||||
// (*entry)->get().get_executable() will return nullptr.
|
// (*entry)->get().get_executable() will return nullptr.
|
||||||
virtual Status Lookup(const string& proto_key,
|
virtual Status Lookup(const string& proto_key,
|
||||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) = 0;
|
CompilationCacheFetchTarget fetch_target) = 0;
|
||||||
|
|
||||||
virtual Status Lookup(const string& proto_key,
|
virtual Status Lookup(const string& proto_key,
|
||||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry) {
|
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||||
return Lookup(proto_key, std::move(entry),
|
return Lookup(proto_key, std::move(entry),
|
||||||
CompilationCacheFetchTarget::MAIN);
|
CompilationCacheFetchTarget::MAIN);
|
||||||
}
|
}
|
||||||
@ -63,30 +56,33 @@ class TpuCompilationCacheLookup : public ResourceBase {
|
|||||||
// returned in program. The wrapper is guaranteed to be valid only during the
|
// returned in program. The wrapper is guaranteed to be valid only during the
|
||||||
// execution of the Op requesting the proto.
|
// execution of the Op requesting the proto.
|
||||||
virtual Status Lookup(int64 uid, int proto_index,
|
virtual Status Lookup(int64 uid, int proto_index,
|
||||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) = 0;
|
CompilationCacheFetchTarget fetch_target) = 0;
|
||||||
|
|
||||||
virtual Status Lookup(int64 uid, int proto_index,
|
virtual Status Lookup(int64 uid, int proto_index,
|
||||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry) {
|
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||||
return Lookup(uid, proto_index, std::move(entry),
|
return Lookup(uid, proto_index, std::move(entry),
|
||||||
CompilationCacheFetchTarget::MAIN);
|
CompilationCacheFetchTarget::MAIN);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Forward declaration to break cycle dependency graph.
|
||||||
|
class TpuCompilationCacheExternal;
|
||||||
|
|
||||||
// Class for looking up ISA protos when the execute and compile Op are in the
|
// Class for looking up ISA protos when the execute and compile Op are in the
|
||||||
// same address space. The proto is simply looked up in the compilation cache,
|
// same address space. The proto is simply looked up in the compilation cache,
|
||||||
// without any serialization taking place.
|
// without any serialization taking place.
|
||||||
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
|
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
|
||||||
public:
|
public:
|
||||||
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
|
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheExternal* cache);
|
||||||
~TpuCompilationCacheLocalLookup() override;
|
~TpuCompilationCacheLocalLookup() override;
|
||||||
|
|
||||||
Status Lookup(const string& proto_key,
|
Status Lookup(const string& proto_key,
|
||||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) override;
|
CompilationCacheFetchTarget fetch_target) override;
|
||||||
|
|
||||||
Status Lookup(int64 uid, int proto_index,
|
Status Lookup(int64 uid, int proto_index,
|
||||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) override;
|
CompilationCacheFetchTarget fetch_target) override;
|
||||||
|
|
||||||
string DebugString() const override;
|
string DebugString() const override;
|
||||||
@ -94,10 +90,10 @@ class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
|
|||||||
private:
|
private:
|
||||||
// The subgraph compilation cache, in the same process address space where the
|
// The subgraph compilation cache, in the same process address space where the
|
||||||
// lookups are happening.
|
// lookups are happening.
|
||||||
TpuCompilationCacheInterface* cache_;
|
TpuCompilationCacheExternal* cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
|
#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||||
|
@ -28,7 +28,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
||||||
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
|
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_util.h"
|
#include "tensorflow/core/tpu/kernels/tpu_util.h"
|
||||||
#include "tensorflow/core/tpu/tpu_configuration.h"
|
#include "tensorflow/core/tpu/tpu_configuration.h"
|
||||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/cc/framework/ops.h"
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/compile_only_client.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
|
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||||
|
Loading…
Reference in New Issue
Block a user