Refactor the TpuCompilationCacheEntry interface to return TpuProgramGroupInterface and core_index, make CacheEntry less transparent, and move application-specific logic outside of the cache.
PiperOrigin-RevId: 324705343 Change-Id: I9dc421df069dbe7dc9bb57695f06e8b636fbc945
This commit is contained in:
parent
9474df4a12
commit
3cf7683cfe
@ -92,8 +92,6 @@ tf_kernel_library(
|
||||
deps = [
|
||||
":tpu_compilation_cache_factory",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_cache_local_lookup",
|
||||
":tpu_compilation_cache_lookup",
|
||||
":tpu_mesh_state_interface",
|
||||
":tpu_op_consts",
|
||||
"//tensorflow/c:tf_status",
|
||||
@ -210,14 +208,30 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_entry",
|
||||
srcs = ["tpu_compilation_cache_entry.cc"],
|
||||
hdrs = [
|
||||
"tpu_compilation_cache_entry.h",
|
||||
],
|
||||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
":tpu_executable_info_proto_cc",
|
||||
":tpu_program_group_interface",
|
||||
":tpu_program_group",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/lib/core:refcount",
|
||||
"//tensorflow/core/platform:casts",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_entry_impl",
|
||||
srcs = [],
|
||||
hdrs = ["tpu_compilation_cache_entry_impl.h"],
|
||||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_executable_info_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
@ -288,8 +302,6 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/core/lib/core:status",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -329,7 +341,6 @@ cc_library(
|
||||
hdrs = ["tpu_compilation_cache_interface.h"],
|
||||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_entry",
|
||||
":tpu_compilation_cache_key",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
":tpu_compilation_metrics_hdrs",
|
||||
@ -361,6 +372,7 @@ cc_library(
|
||||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_entry",
|
||||
":tpu_compilation_cache_entry_impl",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_cache_key",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
@ -370,7 +382,6 @@ cc_library(
|
||||
":tpu_compile_op_support",
|
||||
":tpu_mesh_state_interface",
|
||||
":tpu_op_consts",
|
||||
":tpu_program_c_api_hdrs",
|
||||
":tpu_program_group",
|
||||
":tpu_util",
|
||||
":trace_util_hdrs",
|
||||
@ -380,10 +391,10 @@ cc_library(
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:span",
|
||||
@ -604,7 +615,6 @@ cc_library(
|
||||
deps = [
|
||||
":tpu_compilation_cache_entry",
|
||||
":tpu_compilation_cache_external",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_cache_local_lookup",
|
||||
":tpu_compilation_cache_lookup",
|
||||
":tpu_executable_info_proto_cc",
|
||||
|
||||
54
tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc
Normal file
54
tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
TpuCompilationCacheEntry::TpuCompilationCacheEntry(
|
||||
const TpuProgramGroupInterface* tpu_program_group, int core_index)
|
||||
: tpu_program_group_(
|
||||
tensorflow::down_cast<const TpuProgramGroup*>(tpu_program_group)),
|
||||
core_index_(core_index) {}
|
||||
|
||||
// Constructor for an empty entry.
|
||||
TpuCompilationCacheEntry::TpuCompilationCacheEntry()
|
||||
: tpu_program_group_(nullptr) {}
|
||||
|
||||
// Returns the address of the executable info held by the program group.
// Must not be called on an empty entry: the program group pointer is
// dereferenced unconditionally.
const TPUExecutableInfoProto* TpuCompilationCacheEntry::get_executable_info()
    const {
  const TPUExecutableInfoProto& info = tpu_program_group_->executable_info();
  return &info;
}
|
||||
|
||||
// Returns the address of the host transfer info held by the program group.
// Must not be called on an empty entry: the program group pointer is
// dereferenced unconditionally.
const TPUHostTransferInfoProto*
TpuCompilationCacheEntry::get_host_transfer_info() const {
  const TPUHostTransferInfoProto& info =
      tpu_program_group_->host_transfer_info();
  return &info;
}
|
||||
|
||||
// Returns the element at position `core_index_` of the program group's
// hlo_metadatas(). No bounds or null checks are performed.
const xla::HloProto* TpuCompilationCacheEntry::get_hlo_metadata() const {
  const auto& metadatas = tpu_program_group_->hlo_metadatas();
  return metadatas[core_index_];
}
|
||||
|
||||
// TODO(henrytan,jiawenhao): When should we expect more than one
|
||||
// XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
|
||||
const XLA_TpuProgram* TpuCompilationCacheEntry::get_tpu_program() const {
|
||||
CHECK_EQ(tpu_program_group_->program_count(), 1);
|
||||
return tpu_program_group_->tpu_programs()[core_index_];
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
@ -18,32 +18,30 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// Cache entry to hold a `TpuProgramGroupInterface` object that can be used to
|
||||
// fetch a TPU program for a given TPU core index.
|
||||
// A version of `CompilationCacheEntry` to access Tpu binary program
|
||||
// `XLA_TpuProgram`.
|
||||
class TpuCompilationCacheEntry {
|
||||
public:
|
||||
explicit TpuCompilationCacheEntry(
|
||||
const TpuProgramGroupInterface* tpu_program_group, int core_index)
|
||||
: tpu_program_group_(tpu_program_group), core_index_(core_index) {}
|
||||
|
||||
const TpuProgramGroupInterface* tpu_program_group, int core_index);
|
||||
// Constructor for an empty entry.
|
||||
TpuCompilationCacheEntry() : tpu_program_group_(nullptr), core_index_(-1) {}
|
||||
|
||||
const TpuProgramGroupInterface* tpu_program_group() const {
|
||||
return tpu_program_group_;
|
||||
}
|
||||
|
||||
int core_index() const { return core_index_; }
|
||||
TpuCompilationCacheEntry();
|
||||
const TPUExecutableInfoProto* get_executable_info() const;
|
||||
const TPUHostTransferInfoProto* get_host_transfer_info() const;
|
||||
const xla::HloProto* get_hlo_metadata() const;
|
||||
// TODO(henrytan): maybe nicer to return C++ wrapper of `XLA_TpuProgram`
|
||||
const XLA_TpuProgram* get_tpu_program() const;
|
||||
|
||||
private:
|
||||
const TpuProgramGroupInterface* tpu_program_group_;
|
||||
const TpuProgramGroup* tpu_program_group_;
|
||||
int core_index_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
||||
@ -0,0 +1,94 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
// Wrapper for a cache entry that holds a reference to the entry until the
|
||||
// wrapper is deleted. This wrapper is the concrete type of
|
||||
// CompilationCacheEntryRef returned by Lookup.
|
||||
template <typename CacheEntryType>
|
||||
class CompilationCacheEntryRefImpl
|
||||
: public CompilationCacheEntryRef<CacheEntryType> {
|
||||
public:
|
||||
CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent,
|
||||
CompiledSubgraph* entry, int index);
|
||||
~CompilationCacheEntryRefImpl() override;
|
||||
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override;
|
||||
|
||||
protected:
|
||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
||||
// A reference to entry_ is acquired in the constructor and released via
|
||||
// parent->DiscardEntryRefs in the destructor.
|
||||
CompiledSubgraph* entry_;
|
||||
// The index of the program in entry_ that is returned by the get method.
|
||||
int index_;
|
||||
};
|
||||
template <typename CacheEntryType>
|
||||
CompilationCacheEntryRefImpl<CacheEntryType>::CompilationCacheEntryRefImpl(
|
||||
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
||||
: parent_(parent), entry_(entry), index_(index) {
|
||||
if (entry_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (entry_->main_entry == nullptr) {
|
||||
entry_->Ref();
|
||||
} else {
|
||||
// This is a sharding/unsharding entry nested in a main entry. Only
|
||||
// refcount the main entry.
|
||||
entry_->main_entry->Ref();
|
||||
}
|
||||
}
|
||||
template <typename CacheEntryType>
|
||||
CompilationCacheEntryRefImpl<CacheEntryType>::~CompilationCacheEntryRefImpl() {
|
||||
if (entry_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (entry_->main_entry == nullptr) {
|
||||
parent_->DiscardEntryRefs({entry_});
|
||||
} else {
|
||||
parent_->DiscardEntryRefs({entry_->main_entry});
|
||||
}
|
||||
}
|
||||
template <typename CacheEntryType>
|
||||
Status CompilationCacheEntryRefImpl<CacheEntryType>::ToSubEntryRef(
|
||||
CompilationCacheFetchTarget fetch_target) {
|
||||
CompiledSubgraph* target = nullptr;
|
||||
switch (fetch_target) {
|
||||
case CompilationCacheFetchTarget::MAIN:
|
||||
target = entry_;
|
||||
break;
|
||||
case CompilationCacheFetchTarget::SHARDING:
|
||||
target = entry_->sharding_entry.get();
|
||||
break;
|
||||
case CompilationCacheFetchTarget::UNSHARDING:
|
||||
target = entry_->unsharding_entry.get();
|
||||
break;
|
||||
default:
|
||||
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
|
||||
}
|
||||
if (target == nullptr) {
|
||||
// Cache entry does not have an unsharding subentry. Unref and replace
|
||||
// with nullptr.
|
||||
parent_->DiscardEntryRefs({entry_});
|
||||
}
|
||||
// Otherwise, since the refcount is always on the main entry, we don't
|
||||
// need ref/unref.
|
||||
entry_ = target;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
||||
@ -16,18 +16,15 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_util.h"
|
||||
#include "tensorflow/core/tpu/kernels/trace_util.h"
|
||||
|
||||
@ -51,22 +48,23 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry,
|
||||
entry->tpu_program_group =
|
||||
absl::make_unique<TpuProgramGroup>(std::move(tpu_program_group));
|
||||
entry->initialized = true;
|
||||
|
||||
if (entry->initialization_status.ok()) {
|
||||
// Compute the entries total size once all members are initialized.
|
||||
entry->total_size = entry->ComputeTotalSize();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<CompiledSubgraph> CreateAndInitializeCompiledSubgraph(
|
||||
CompiledSubgraph* main_entry) {
|
||||
auto entry = absl::make_unique<CompiledSubgraph>();
|
||||
entry->main_entry = main_entry;
|
||||
entry->tpu_program_group = absl::make_unique<TpuProgramGroup>();
|
||||
return entry;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl(
|
||||
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
||||
: CompilationCacheEntryRefImpl<TpuCompilationCacheEntry>(parent, entry,
|
||||
index) {}
|
||||
|
||||
TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() {
|
||||
if (entry_ == nullptr) {
|
||||
// Create an empty entry if the entry is nullptr. This corresponds to
|
||||
// non-existing sharding/unsharding entries.
|
||||
return TpuCompilationCacheEntry();
|
||||
}
|
||||
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
|
||||
}
|
||||
|
||||
CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
||||
const string& key,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>& initialize_program,
|
||||
@ -75,6 +73,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
||||
main_entry->parent = this;
|
||||
main_entry->subgraph_key = key;
|
||||
main_entry->uid = get_uid();
|
||||
// TODO(henrytan): implement TpuCompilationCacheKey.debug_string.
|
||||
main_entry->cache_entry_debug_string = subgraph_key.prefix;
|
||||
VLOG(1) << "Cache Initializing Entry Session Debug "
|
||||
<< main_entry->cache_entry_debug_string;
|
||||
@ -113,29 +112,17 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
||||
std::pair<int64, CompiledSubgraph*>(main_entry->uid, main_entry));
|
||||
CHECK(uid_inserted.second);
|
||||
|
||||
if (tpu_program_group.has_sharding_program()) {
|
||||
main_entry->sharding_entry =
|
||||
CreateAndInitializeCompiledSubgraph(main_entry);
|
||||
TpuProgramGroup sharding_programs;
|
||||
sharding_programs.Initialize(
|
||||
tpu_program_group.tpu_programs(TpuProgramShardingType::kSharding));
|
||||
PopulateEntry(key, main_entry->sharding_entry.get(),
|
||||
std::move(sharding_programs));
|
||||
|
||||
main_entry->unsharding_entry =
|
||||
CreateAndInitializeCompiledSubgraph(main_entry);
|
||||
TpuProgramGroup unsharding_programs;
|
||||
unsharding_programs.Initialize(
|
||||
tpu_program_group.tpu_programs(TpuProgramShardingType::kUnsharding));
|
||||
PopulateEntry(key, main_entry->unsharding_entry.get(),
|
||||
std::move(unsharding_programs));
|
||||
if (initialization_status.ok()) {
|
||||
// Compute the entries total size once all members are initialized.
|
||||
main_entry->total_size = tpu_program_group.program_size();
|
||||
}
|
||||
|
||||
// TODO(henrytan): handle sharding/unsharding.
|
||||
PopulateEntry(key, main_entry, std::move(tpu_program_group));
|
||||
|
||||
for (int64 i = 0; i < main_entry->proto_key.size(); ++i) {
|
||||
auto entry_inserted = entries_by_proto_key_.insert(
|
||||
std::pair<std::string, std::pair<CompiledSubgraph*, int>>(
|
||||
std::pair<string, std::pair<CompiledSubgraph*, int>>(
|
||||
main_entry->proto_key[i], std::make_pair(main_entry, i)));
|
||||
CHECK(entry_inserted.second);
|
||||
}
|
||||
|
||||
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||
@ -45,6 +46,17 @@ namespace tpu {
|
||||
|
||||
class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
||||
public:
|
||||
using Status = ::stream_executor::port::Status;
|
||||
|
||||
class EntryRefImpl
|
||||
: public CompilationCacheEntryRefImpl<TpuCompilationCacheEntry> {
|
||||
public:
|
||||
EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry,
|
||||
int index);
|
||||
|
||||
TpuCompilationCacheEntry get() override;
|
||||
};
|
||||
|
||||
explicit TpuCompilationCacheExternal(int64 max_cache_size)
|
||||
: TpuCompilationCacheInterface(max_cache_size) {}
|
||||
|
||||
|
||||
@ -38,77 +38,10 @@ void TpuCompilationCacheInterface::RefHolder::AddRef(CompiledSubgraph* entry) {
|
||||
entries_.push_back(entry);
|
||||
}
|
||||
|
||||
std::string TpuCompilationCacheInterface::RefHolder::DebugString() const {
|
||||
string TpuCompilationCacheInterface::RefHolder::DebugString() const {
|
||||
return "TpuCompilationCacheRefHolder";
|
||||
}
|
||||
|
||||
CompilationCacheEntryRef::CompilationCacheEntryRef()
|
||||
: parent_(nullptr), entry_(nullptr), index_(0) {}
|
||||
|
||||
CompilationCacheEntryRef::CompilationCacheEntryRef(
|
||||
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
||||
: parent_(parent), entry_(entry), index_(index) {
|
||||
if (entry_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (entry_->main_entry == nullptr) {
|
||||
entry_->Ref();
|
||||
} else {
|
||||
// This is a sharding/unsharding entry nested in a main entry. Only
|
||||
// refcount the main entry.
|
||||
entry_->main_entry->Ref();
|
||||
}
|
||||
}
|
||||
|
||||
CompilationCacheEntryRef::~CompilationCacheEntryRef() {
|
||||
if (entry_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (entry_->main_entry == nullptr) {
|
||||
parent_->DiscardEntryRefs({entry_});
|
||||
} else {
|
||||
parent_->DiscardEntryRefs({entry_->main_entry});
|
||||
}
|
||||
}
|
||||
|
||||
TpuCompilationCacheEntry CompilationCacheEntryRef::get() {
|
||||
if (entry_ == nullptr) {
|
||||
// Create an empty entry if the entry is nullptr. This corresponds to
|
||||
// non-existing sharding/unsharding entries.
|
||||
return TpuCompilationCacheEntry();
|
||||
}
|
||||
|
||||
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
|
||||
}
|
||||
|
||||
Status CompilationCacheEntryRef::ToSubEntryRef(
|
||||
CompilationCacheFetchTarget fetch_target) {
|
||||
CompiledSubgraph* target = nullptr;
|
||||
switch (fetch_target) {
|
||||
case CompilationCacheFetchTarget::MAIN:
|
||||
target = entry_;
|
||||
break;
|
||||
case CompilationCacheFetchTarget::SHARDING:
|
||||
target = entry_->sharding_entry.get();
|
||||
break;
|
||||
case CompilationCacheFetchTarget::UNSHARDING:
|
||||
target = entry_->unsharding_entry.get();
|
||||
break;
|
||||
default:
|
||||
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
|
||||
}
|
||||
|
||||
if (target == nullptr) {
|
||||
// Cache entry does not have an unsharding subentry. Unref and replace
|
||||
// with nullptr.
|
||||
parent_->DiscardEntryRefs({entry_});
|
||||
}
|
||||
// Otherwise, since the refcount is always on the main entry, we don't
|
||||
// need ref/unref.
|
||||
entry_ = target;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TpuCompilationCacheInterface::TpuCompilationCacheInterface(int64 max_cache_size)
|
||||
: max_cache_size_(max_cache_size) {
|
||||
CHECK_GE(max_cache_size_, 0);
|
||||
@ -223,7 +156,7 @@ void TpuCompilationCacheInterface::UnloadAndDestroy(CompiledSubgraph* entry) {
|
||||
entry->Unref();
|
||||
}
|
||||
|
||||
size_t TpuCompilationCacheInterface::RemoveEntry(const std::string& key) {
|
||||
size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) {
|
||||
auto erased = cache_.erase(key);
|
||||
TpuCompilationMetrics::SetCacheEntryCount(cache_.size());
|
||||
|
||||
@ -263,7 +196,7 @@ CompiledSubgraph* TpuCompilationCacheInterface::DiscardEntryRef(
|
||||
}
|
||||
erased = entries_by_uid_.erase(entry->uid);
|
||||
CHECK_EQ(erased, 1);
|
||||
for (const std::string& key : entry->proto_key) {
|
||||
for (const string& key : entry->proto_key) {
|
||||
erased = entries_by_proto_key_.erase(key);
|
||||
CHECK_EQ(erased, 1);
|
||||
}
|
||||
@ -336,10 +269,10 @@ void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
|
||||
}
|
||||
}
|
||||
|
||||
void TpuCompilationCacheInterface::InsertEntry(const std::string& key,
|
||||
void TpuCompilationCacheInterface::InsertEntry(const string& key,
|
||||
CompiledSubgraph* entry) {
|
||||
auto cache_inserted =
|
||||
cache_.insert(std::pair<std::string, CompiledSubgraph*>(key, entry));
|
||||
cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
|
||||
CHECK(cache_inserted.second);
|
||||
TpuCompilationMetrics::SetCacheEntryCount(cache_.size());
|
||||
|
||||
@ -362,8 +295,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<std::string>* proto_key,
|
||||
std::vector<bool>* may_modify_variables,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
|
||||
std::vector<CompiledSubgraph*> removed_entries;
|
||||
@ -376,7 +308,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
|
||||
return status;
|
||||
}
|
||||
|
||||
std::string TpuCompilationCacheInterface::FindCacheKey(
|
||||
string TpuCompilationCacheInterface::FindCacheKey(
|
||||
const TpuCompilationCacheKey& subgraph_key) {
|
||||
if (!subgraph_key.has_guaranteed_const) {
|
||||
return subgraph_key.prefix;
|
||||
@ -399,8 +331,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<std::string>* proto_key,
|
||||
std::vector<bool>* may_modify_variables,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
std::vector<CompiledSubgraph*>* removed_entries,
|
||||
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
|
||||
@ -414,18 +345,17 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
||||
// for the lifetime of the object, see InitializeEntry() call below.
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
std::string cache_key = FindCacheKey(subgraph_key);
|
||||
string cache_key = FindCacheKey(subgraph_key);
|
||||
auto iter = cache_.find(cache_key);
|
||||
bool is_new_key = iter == cache_.end();
|
||||
|
||||
const std::string session_name =
|
||||
tpu::SessionNameFromMetadata(session_metadata);
|
||||
const string session_name = tpu::SessionNameFromMetadata(session_metadata);
|
||||
|
||||
if (is_new_key) {
|
||||
cache_key = subgraph_key.ToString();
|
||||
TpuCompilationMetrics::IncrementCacheLookupCount(
|
||||
/*is_cache_hit=*/false, session_name);
|
||||
const std::string msg =
|
||||
const string msg =
|
||||
strings::StrCat("TPU host compilation cache miss: cache_key(",
|
||||
cache_key, "), session_name(", session_name, ")");
|
||||
TRACESTRING(msg);
|
||||
@ -434,7 +364,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
||||
// Check if caller has disabled compilation. Set using
|
||||
// internal::ScopedTpuCompileDisabler.
|
||||
if (!UtilApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
|
||||
const std::string error_msg = strings::StrCat(
|
||||
const string error_msg = strings::StrCat(
|
||||
"[TpuCompilationDisabled]: Compilation cache miss, but compilation "
|
||||
"disabled, session_name(",
|
||||
session_name, ") Debug String: ", subgraph_key.debug_string);
|
||||
@ -473,7 +403,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
||||
} else {
|
||||
TpuCompilationMetrics::IncrementCacheLookupCount(
|
||||
/*is_cache_hit=*/true, session_name);
|
||||
const std::string msg =
|
||||
const string msg =
|
||||
strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
|
||||
"), session_name(", session_name, ")");
|
||||
TRACESTRING(msg);
|
||||
@ -536,8 +466,8 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
||||
return entry->initialization_status;
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheInterface::GetKeysFromUid(
|
||||
int64 uid, std::vector<std::string>* keys) {
|
||||
Status TpuCompilationCacheInterface::GetKeysFromUid(int64 uid,
|
||||
std::vector<string>* keys) {
|
||||
keys->clear();
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
@ -549,49 +479,5 @@ Status TpuCompilationCacheInterface::GetKeysFromUid(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheInterface::Lookup(
|
||||
int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
entry->reset();
|
||||
|
||||
profiler::TraceMe proto_lookup_traceme(
|
||||
"TPU compilation cache proto lookup by uid",
|
||||
/*level=*/2);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto iter = entries_by_uid_.find(uid);
|
||||
if (iter == entries_by_uid_.end()) {
|
||||
return errors::NotFound("No subgraph found for uid ", uid);
|
||||
}
|
||||
CompiledSubgraph* cache_entry = iter->second;
|
||||
if (proto_index < 0 ||
|
||||
proto_index >= cache_entry->tpu_program_group->program_count()) {
|
||||
return errors::NotFound("No proto found for core index ", proto_index,
|
||||
" in subgraph with uid ", uid);
|
||||
}
|
||||
*entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
|
||||
proto_index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheInterface::Lookup(
|
||||
const std::string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
entry->reset();
|
||||
|
||||
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
|
||||
/*level=*/2);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto iter = entries_by_proto_key_.find(proto_key);
|
||||
if (iter == entries_by_proto_key_.end()) {
|
||||
return errors::NotFound("No proto found for key ", proto_key);
|
||||
}
|
||||
CompiledSubgraph* cache_entry = iter->second.first;
|
||||
int proto_index = iter->second.second;
|
||||
*entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
|
||||
proto_index);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -32,7 +32,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
|
||||
#include "tensorflow/core/tpu/kernels/trace_util.h"
|
||||
@ -49,20 +48,18 @@ class CompilationRefHolder : public ResourceBase {
|
||||
~CompilationRefHolder() override = default;
|
||||
};
|
||||
|
||||
// Wrapper for a cache entry returned by all the TpuCompilationCacheInterface
|
||||
// `Lookup` methods, and ensures the underlying proto is not garbage-collected
|
||||
// until the client discards the ptr.
|
||||
// Base class for a reference to a cached tpu program. A unique_ptr to a
|
||||
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
|
||||
// and ensures the underlying proto is not garbage-collected until the client
|
||||
// discards the ptr.
|
||||
template <typename CacheEntryType>
|
||||
class CompilationCacheEntryRef {
|
||||
public:
|
||||
CompilationCacheEntryRef();
|
||||
CompilationCacheEntryRef(TpuCompilationCacheInterface* parent,
|
||||
CompiledSubgraph* entry, int index);
|
||||
virtual ~CompilationCacheEntryRef() = default;
|
||||
|
||||
virtual ~CompilationCacheEntryRef();
|
||||
|
||||
// Returns a TpuCompilationCacheEntry that should not be used beyond the
|
||||
// lifetime of the CompilationCacheEntryRef.
|
||||
virtual TpuCompilationCacheEntry get();
|
||||
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
|
||||
// of the tpu::CompilationCacheEntryRef.
|
||||
virtual CacheEntryType get() = 0;
|
||||
|
||||
// Mutates this ref to point to the entry's subentry (for
|
||||
// sharding/unsharding) or main entry (unchanged) as specified by
|
||||
@ -72,15 +69,7 @@ class CompilationCacheEntryRef {
|
||||
//
|
||||
// If the requested subentry does not exist, the ref will point to a nullptr
|
||||
// entry, and the original entry will be unref'ed.
|
||||
virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
|
||||
|
||||
protected:
|
||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
||||
// A reference to entry_ is acquired in the constructor and released via
|
||||
// parent->DiscardEntryRefs in the destructor.
|
||||
CompiledSubgraph* entry_;
|
||||
// The index of the program in entry_ that is returned by the get method.
|
||||
int index_;
|
||||
virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0;
|
||||
};
|
||||
|
||||
class TpuCompilationCacheInterface : public ResourceBase {
|
||||
@ -108,8 +97,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<std::string>* proto_key,
|
||||
std::vector<bool>* may_modify_variables,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
||||
|
||||
@ -136,18 +124,19 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
// the subgraph represented by key. On success a pointer to an EntryRef
|
||||
// holding the program is returned in entry.
|
||||
Status Lookup(const std::string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry);
|
||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||
Status Lookup(const string& proto_key, std::unique_ptr<CacheEntryRef>* entry);
|
||||
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
// the subgraph represented by uid. On success a pointer to an EntryRef
|
||||
// holding the program is returned in entry.
|
||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||
Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry);
|
||||
std::unique_ptr<CacheEntryRef>* entry);
|
||||
|
||||
// Looks up the subgraph represented by uid, and returns the vector of keys,
|
||||
// one per core, corresponding to that subgraph.
|
||||
Status GetKeysFromUid(int64 uid, std::vector<std::string>* keys);
|
||||
Status GetKeysFromUid(int64 uid, std::vector<string>* keys);
|
||||
|
||||
// Makes a reference holder for this cache, that can be stored in the per-step
|
||||
// resource manager and will ensure that compiled entries persist until the
|
||||
@ -181,7 +170,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
// parent_->DiscardEntryRefs.
|
||||
void AddRef(CompiledSubgraph* entry);
|
||||
|
||||
std::string DebugString() const override;
|
||||
string DebugString() const override;
|
||||
|
||||
private:
|
||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
||||
@ -196,8 +185,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
const TpuCompilationCacheKey& subgraph_key,
|
||||
const SessionMetadata* session_metadata,
|
||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||
std::vector<std::string>* proto_key,
|
||||
std::vector<bool>* may_modify_variables,
|
||||
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||
std::vector<CompiledSubgraph*>* removed_entries,
|
||||
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
||||
@ -242,14 +230,14 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Removes the entry with given key from cache.
|
||||
size_t RemoveEntry(const std::string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Inserts the given key and entry to cache.
|
||||
void InsertEntry(const std::string& key, CompiledSubgraph* entry)
|
||||
void InsertEntry(const string& key, CompiledSubgraph* entry)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Returns the cache key matching given subgraph_key.
|
||||
std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
|
||||
string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Creates a new entry by running initialize_programs and places it in the
|
||||
@ -259,7 +247,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
//
|
||||
// **InitializeEntry releases mu_ during the call to initialize_programs.**
|
||||
virtual CompiledSubgraph* InitializeEntry(
|
||||
const std::string& key,
|
||||
const string& key,
|
||||
const std::function<Status(TpuProgramGroupInterface*)>&
|
||||
initialize_programs,
|
||||
const TpuCompilationCacheKey& subgraph_key)
|
||||
@ -288,16 +276,13 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
// cache_ key matching a given subgraph key. When doing a lookup, check
|
||||
// session_key_map_ first to avoid unnecessary fingerprint computation.
|
||||
// Map from key prefix + session_handle to a cache_ key.
|
||||
absl::node_hash_map<std::string, std::string> session_key_map_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
absl::node_hash_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
|
||||
// Map from key prefix + fingerprint to a cache_ key.
|
||||
absl::node_hash_map<std::string, std::string> fingerprint_key_map_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
absl::node_hash_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
|
||||
// All the subgraph entries that can be looked up in the cache. An entry is
|
||||
// marked for eviction iff it is present in cache_ and not in
|
||||
// entries_by_last_use_.
|
||||
std::unordered_map<std::string, CompiledSubgraph*> cache_
|
||||
ABSL_GUARDED_BY(mu_);
|
||||
std::unordered_map<string, CompiledSubgraph*> cache_ ABSL_GUARDED_BY(mu_);
|
||||
// All the subgraph entries that can be looked up in the cache, indexed by
|
||||
// uid.
|
||||
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
|
||||
@ -305,7 +290,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
// All the protos that can be looked up in the cache, indexed by proto
|
||||
// key. The value of the map is a subgraph and the index of the proto compiled
|
||||
// for that subgraph.
|
||||
std::unordered_map<std::string, std::pair<CompiledSubgraph*, int>>
|
||||
std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
|
||||
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
|
||||
// Map from last_use to entry, used to mark entries for eviction in LRU
|
||||
// order. If an entry's last_use counter is not present as a key in
|
||||
@ -319,6 +304,50 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
||||
TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
|
||||
delete;
|
||||
};
|
||||
|
||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||
Status TpuCompilationCacheInterface::Lookup(
|
||||
int64 uid, int proto_index, std::unique_ptr<CacheEntryRef>* entry) {
|
||||
entry->reset();
|
||||
|
||||
profiler::TraceMe proto_lookup_traceme(
|
||||
"TPU compilation cache proto lookup by uid",
|
||||
/*level=*/2);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto iter = entries_by_uid_.find(uid);
|
||||
if (iter == entries_by_uid_.end()) {
|
||||
return errors::NotFound("No subgraph found for uid ", uid);
|
||||
}
|
||||
CompiledSubgraph* cache_entry = iter->second;
|
||||
if (proto_index < 0 ||
|
||||
proto_index >= cache_entry->tpu_program_group->program_count()) {
|
||||
return errors::NotFound("No proto found for core index ", proto_index,
|
||||
" in subgraph with uid ", uid);
|
||||
}
|
||||
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||
Status TpuCompilationCacheInterface::Lookup(
|
||||
const string& proto_key, std::unique_ptr<CacheEntryRef>* entry) {
|
||||
entry->reset();
|
||||
|
||||
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
|
||||
/*level=*/2);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto iter = entries_by_proto_key_.find(proto_key);
|
||||
if (iter == entries_by_proto_key_.end()) {
|
||||
return errors::NotFound("No proto found for key ", proto_key);
|
||||
}
|
||||
CompiledSubgraph* cache_entry = iter->second.first;
|
||||
int proto_index = iter->second.second;
|
||||
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
||||
@ -16,50 +16,70 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
namespace {
|
||||
class CompilationCacheFetchTargetUtility {
|
||||
public:
|
||||
CompilationCacheFetchTargetUtility()
|
||||
: names_({"Invalid", "Main", "Sharding", "Unsharding"}) {}
|
||||
|
||||
std::string name(CompilationCacheFetchTarget target) const {
|
||||
return names_[static_cast<int>(target)];
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<std::string> names_;
|
||||
};
|
||||
|
||||
// Returns the display name of `target` for use in log output.
std::string GetName(CompilationCacheFetchTarget target) {
  // Allocated on first use and never deleted, so no destructor runs for it at
  // program exit.
  static const CompilationCacheFetchTargetUtility& utility =
      *new CompilationCacheFetchTargetUtility();
  return utility.name(target);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
|
||||
TpuCompilationCacheInterface* cache)
|
||||
: cache_(cache) {
|
||||
cache_->Ref();
|
||||
}
|
||||
: cache_(cache) {}
|
||||
|
||||
// Drops one reference on the compilation cache this lookup points at.
// NOTE(review): the constructor in this change no longer calls cache_->Ref();
// confirm that whoever constructs this object hands over a reference,
// otherwise this Unref() is unbalanced.
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
|
||||
cache_->Unref();
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheLocalLookup::Lookup(
|
||||
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
const string& proto_key,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) {
|
||||
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
|
||||
/*level=*/2);
|
||||
Status s = cache_->Lookup(proto_key, entry);
|
||||
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
|
||||
proto_key, entry);
|
||||
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
|
||||
<< s;
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
s = (*entry)->ToSubEntryRef(fetch_target);
|
||||
VLOG(1) << "Fetched subentry: "
|
||||
<< CompilationCacheFetchTarget_Name(fetch_target) << " with status "
|
||||
|
||||
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
||||
<< s;
|
||||
return s;
|
||||
}
|
||||
|
||||
Status TpuCompilationCacheLocalLookup::Lookup(
|
||||
int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) {
|
||||
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
|
||||
/*level=*/2);
|
||||
Status s = cache_->Lookup(uid, proto_index, entry);
|
||||
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
|
||||
uid, proto_index, entry);
|
||||
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
|
||||
<< " in local subgraph cache status " << s;
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
s = (*entry)->ToSubEntryRef(fetch_target);
|
||||
VLOG(1) << "Fetched subentry: "
|
||||
<< CompilationCacheFetchTarget_Name(fetch_target) << " with status "
|
||||
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
||||
<< s;
|
||||
return s;
|
||||
}
|
||||
@ -67,5 +87,6 @@ Status TpuCompilationCacheLocalLookup::Lookup(
|
||||
// Returns a fixed string identifying this resource type; it carries no
// per-instance state.
string TpuCompilationCacheLocalLookup::DebugString() const {
|
||||
return "TpuCompilationCacheLocalLookup";
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -28,17 +28,24 @@ namespace tpu {
|
||||
// Class for looking up TPU programs when the execute and compile Op are in the
|
||||
// same address space. The proto is simply looked up in the compilation cache,
|
||||
// without any serialization taking place.
|
||||
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
|
||||
class TpuCompilationCacheLocalLookup
|
||||
: public TpuCompilationCacheLookup<
|
||||
CompilationCacheEntryRef<TpuCompilationCacheEntry>> {
|
||||
public:
|
||||
using TpuCompilationCacheEntryRef =
|
||||
::tensorflow::tpu::CompilationCacheEntryRef<TpuCompilationCacheEntry>;
|
||||
using EntryRefImpl =
|
||||
::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl;
|
||||
|
||||
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
|
||||
~TpuCompilationCacheLocalLookup() override;
|
||||
|
||||
Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) override;
|
||||
|
||||
Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) override;
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
@ -23,11 +23,10 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// TODO(b/162241759): consider merging TpuCompilationCacheLookup and
|
||||
// TpuCompilationCacheInterface.
|
||||
// Base class allowing Execute Ops to look up TPU programs. Different subclasses
|
||||
// are used when the execute Op is in the same address space as the compile Op,
|
||||
// and when they need to communicate over RPC.
|
||||
template <typename TpuCompilationCacheEntryRefType>
|
||||
class TpuCompilationCacheLookup : public ResourceBase {
|
||||
public:
|
||||
~TpuCompilationCacheLookup() override = default;
|
||||
@ -44,11 +43,12 @@ class TpuCompilationCacheLookup : public ResourceBase {
|
||||
// fetch_target requests one of them, then after this call
|
||||
// (*entry)->get().get_executable() will return nullptr.
|
||||
virtual Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) = 0;
|
||||
|
||||
virtual Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
virtual Status Lookup(
|
||||
const string& proto_key,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry) {
|
||||
return Lookup(proto_key, std::move(entry),
|
||||
CompilationCacheFetchTarget::MAIN);
|
||||
}
|
||||
@ -58,15 +58,17 @@ class TpuCompilationCacheLookup : public ResourceBase {
|
||||
// returned in program. The wrapper is guaranteed to be valid only during the
|
||||
// execution of the Op requesting the proto.
|
||||
virtual Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) = 0;
|
||||
|
||||
virtual Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
virtual Status Lookup(
|
||||
int64 uid, int proto_index,
|
||||
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry) {
|
||||
return Lookup(uid, proto_index, std::move(entry),
|
||||
CompilationCacheFetchTarget::MAIN);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
||||
@ -413,6 +413,46 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */
|
||||
Status TpuCompileOpKernelCommon::ComputeArgumentShapes(
|
||||
const tpu::TPUCompileMetadataProto& metadata,
|
||||
const std::vector<TensorShape>& dynamic_shapes,
|
||||
std::vector<TensorShape>* arg_shapes) {
|
||||
arg_shapes->resize(metadata.args_size());
|
||||
int dynamic_shape_pos = 0;
|
||||
for (int i = 0; i < metadata.args_size(); ++i) {
|
||||
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
|
||||
// The XLA compiler determines the shape of each constant by inspecting the
|
||||
// value of its corresponding host-memory tensor. As a result, we don't need
|
||||
// to give the compiler graph-inferred shapes for constant arguments.
|
||||
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
|
||||
continue;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
|
||||
PartialTensorShape static_shape(arg.shape());
|
||||
|
||||
TensorShape& shape = (*arg_shapes)[i];
|
||||
if (static_shape.IsFullyDefined()) {
|
||||
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
|
||||
} else {
|
||||
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
|
||||
<< "Too few dynamic shapes";
|
||||
shape = dynamic_shapes[dynamic_shape_pos++];
|
||||
if (!static_shape.IsCompatibleWith(shape)) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatch between static and dynamic shape for argument. Static "
|
||||
"shape: ",
|
||||
static_shape.DebugString(),
|
||||
"; dynamic shape: ", shape.DebugString());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Checks we consumed all of the dynamic shapes.
|
||||
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
|
||||
<< "Too many dynamic shapes";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function arguments and return values lose their device assignments, so we
|
||||
// must recreate them.
|
||||
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
|
||||
|
||||
@ -99,6 +99,15 @@ class TpuCompileOpKernelCommon {
|
||||
const std::vector<TensorShape>& arg_shapes,
|
||||
TpuProgramGroupInterface* tpu_program_group) = 0;
|
||||
|
||||
// Computes shapes for each argument. Uses both the static shape from the
|
||||
// metadata, and the dynamic shapes where the static shape is not
|
||||
// defined. There must be one dynamic_shape for each argument with a
|
||||
// partially defined shape, in index order.
|
||||
static Status ComputeArgumentShapes(
|
||||
const tpu::TPUCompileMetadataProto& metadata,
|
||||
const std::vector<TensorShape>& dynamic_shapes,
|
||||
std::vector<TensorShape>* arg_shapes);
|
||||
|
||||
// Performs shape inference on `computation`, filling shape_info with operator
|
||||
// shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
|
||||
static Status RunShapeInferenceOnComputation(
|
||||
|
||||
@ -540,43 +540,5 @@ Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Fills `*arg_shapes` with one shape per argument listed in `metadata`.
//
// `*arg_shapes` is resized to the number of arguments. Arguments of kind
// GUARANTEED_CONSTANT are skipped and their slots are not written. Every other
// argument takes its fully defined static shape from the metadata, or, when
// the static shape is only partially defined, the next unused entry of
// `dynamic_shapes` (which must be compatible with the static shape).
//
// Returns an error if a static shape is invalid, if the number of entries in
// `dynamic_shapes` does not match the number of partially defined arguments,
// or if a dynamic shape is incompatible with its static counterpart.
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
                             const std::vector<TensorShape>& dynamic_shapes,
                             std::vector<TensorShape>* arg_shapes) {
  const int num_args = metadata.args_size();
  arg_shapes->resize(num_args);

  // Index of the next entry of `dynamic_shapes` that has not been consumed.
  int dynamic_shape_pos = 0;
  for (int arg_index = 0; arg_index < num_args; ++arg_index) {
    const tpu::TPUCompileMetadataProto::Arg& arg_proto =
        metadata.args(arg_index);

    // XLA derives the shape of a constant from the value of its host-memory
    // tensor, so constants need no graph-inferred shape.
    if (arg_proto.kind() ==
        tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
      continue;
    }

    TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg_proto.shape()));
    PartialTensorShape static_shape(arg_proto.shape());

    TensorShape& shape = (*arg_shapes)[arg_index];
    if (static_shape.IsFullyDefined()) {
      TF_RET_CHECK(static_shape.AsTensorShape(&shape));
      continue;
    }

    // Partially defined: fall back to the next caller-provided dynamic shape.
    TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
        << "Too few dynamic shapes";
    shape = dynamic_shapes[dynamic_shape_pos++];
    if (static_shape.IsCompatibleWith(shape)) {
      continue;
    }
    return errors::InvalidArgument(
        "Mismatch between static and dynamic shape for argument. Static "
        "shape: ",
        static_shape.DebugString(),
        "; dynamic shape: ", shape.DebugString());
  }

  // Every dynamic shape must have been matched to an argument.
  TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
      << "Too many dynamic shapes";
  return Status::OK();
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -159,14 +159,6 @@ se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
||||
TPUCompileMetadataProto* metadata,
|
||||
NameAttrList* function_name,
|
||||
std::string* mlir_module);
|
||||
|
||||
// Computes shapes for each argument. Uses both the static shape from the
|
||||
// metadata, and the dynamic shapes where the static shape is not
|
||||
// defined. There must be one dynamic_shape for each argument with a
|
||||
// partially defined shape, in index order.
|
||||
Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
|
||||
const std::vector<TensorShape>& dynamic_shapes,
|
||||
std::vector<TensorShape>* arg_shapes);
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
||||
@ -25,8 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
|
||||
#include "tensorflow/core/tpu/tpu_api.h"
|
||||
@ -255,10 +253,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||
mesh_state_interface));
|
||||
}
|
||||
|
||||
VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
|
||||
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
|
||||
rmgr, tpu::kCompiledProtoCacheResourceName));
|
||||
|
||||
if (enable_whole_mesh_compilations_) {
|
||||
// If this is a whole mesh compilation mode, create the compilation cache,
|
||||
// if missing.
|
||||
@ -282,13 +276,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||
|
||||
if (local_compilation_cache != nullptr) {
|
||||
local_compilation_cache->Unref();
|
||||
|
||||
tpu::TpuCompilationCacheLookup* proto_lookup;
|
||||
proto_lookup =
|
||||
new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
|
||||
OP_REQUIRES_OK(
|
||||
ctx, rmgr->Create(rmgr->default_container(),
|
||||
tpu::kCompiledProtoCacheResourceName, proto_lookup));
|
||||
}
|
||||
|
||||
Tensor* ctx_output;
|
||||
|
||||
@ -40,12 +40,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||
@ -58,10 +56,14 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
using ::tensorflow::tpu::CompilationCacheEntryRef;
|
||||
using ::tensorflow::tpu::TpuCompilationCacheLookup;
|
||||
|
||||
using ::tensorflow::tpu::TpuNodeContext;
|
||||
using CompilationCacheEntryRef = ::tensorflow::tpu::CompilationCacheEntryRef<
|
||||
::tensorflow::tpu::TpuCompilationCacheEntry>;
|
||||
using TpuCompilationCacheLookup =
|
||||
::tensorflow::tpu::TpuCompilationCacheLookup<CompilationCacheEntryRef>;
|
||||
|
||||
// Looks up the input `key` in the compilation cache, populating
|
||||
// `*rendezvous_key_base` and `*entry`.
|
||||
@ -639,35 +641,28 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
|
||||
profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);
|
||||
|
||||
string rendezvous_key_base;
|
||||
std::unique_ptr<CompilationCacheEntryRef> entry_ref;
|
||||
std::unique_ptr<CompilationCacheEntryRef> entry;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));
|
||||
GetComputationCacheEntry(context, &rendezvous_key_base, &entry));
|
||||
|
||||
// Shapes of the inputs and outputs, in xla::Shape form.
|
||||
tpu::TpuCompilationCacheEntry entry = entry_ref->get();
|
||||
const tpu::TpuProgramGroup* tpu_program_group =
|
||||
tensorflow::down_cast<const tpu::TpuProgramGroup*>(
|
||||
entry.tpu_program_group());
|
||||
CHECK_NE(tpu_program_group, nullptr);
|
||||
const int core_index = entry.core_index();
|
||||
const TPUExecutableInfoProto& executable =
|
||||
tpu_program_group->executable_info(core_index);
|
||||
const TPUExecutableInfoProto* proto = entry->get().get_executable_info();
|
||||
|
||||
xla::Backend* const backend = node_context->backend();
|
||||
xla::TransferManager* const transfer_manager = backend->transfer_manager();
|
||||
TF_RET_CHECK(context->op_device_context());
|
||||
se::Stream* stream = context->op_device_context()->stream();
|
||||
|
||||
TF_RET_CHECK(executable.input_shapes_size() == 1);
|
||||
TF_RET_CHECK(proto->input_shapes_size() == 1);
|
||||
|
||||
xla::Shape host_shape(executable.input_shapes(0));
|
||||
xla::Shape host_shape(proto->input_shapes(0));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto variable_update_map,
|
||||
BuildVariableUpdateMap(executable.variable_indices(),
|
||||
BuildVariableUpdateMap(proto->variable_indices(),
|
||||
fused_device_var_reads_in_computation_inputs_,
|
||||
fused_device_var_updates_in_computation_outputs_,
|
||||
executable.output_tensor_shapes().size()));
|
||||
proto->output_tensor_shapes().size()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<InputBuffers> input_buffers,
|
||||
BuildComputationInputs(context, host_shape, variable_update_map, backend,
|
||||
@ -702,9 +697,8 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
|
||||
|
||||
// Snapshot the inputs, if a snapshot was requested.
|
||||
std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
|
||||
if (executable.has_session_module()) {
|
||||
hlo_snapshot =
|
||||
std::make_shared<xla::HloSnapshot>(executable.session_module());
|
||||
if (proto->has_session_module()) {
|
||||
hlo_snapshot = std::make_shared<xla::HloSnapshot>(proto->session_module());
|
||||
auto literal =
|
||||
std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
|
||||
transfer_manager->TransferLiteralFromDevice(
|
||||
@ -729,9 +723,9 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
|
||||
const uint32 rng_seed = GetXLARandomSeed();
|
||||
|
||||
std::unique_ptr<xla::DeviceAssignment> device_assignment;
|
||||
if (executable.has_device_assignment()) {
|
||||
if (proto->has_device_assignment()) {
|
||||
TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
|
||||
executable.device_assignment()));
|
||||
proto->device_assignment()));
|
||||
}
|
||||
|
||||
VLOG(4) << "Input buffers after alias resolution: "
|
||||
@ -749,24 +743,24 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
|
||||
// we free a memory and reassign it to other users while a program is running,
|
||||
// all subsequent writes to the program that could possibly clobber the memory
|
||||
// will depend on the program to finish.
|
||||
const TPUHostTransferInfoProto& host_transfer_info =
|
||||
tpu_program_group->host_transfer_info(core_index);
|
||||
const TPUHostTransferInfoProto* host_transfer_info =
|
||||
entry->get().get_host_transfer_info();
|
||||
const xla::HloProto* hlo_metadata = entry->get().get_hlo_metadata();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
xla::ExecutionOutput output,
|
||||
TPUExecute(executable, host_transfer_info,
|
||||
*tpu_program_group->hlo_metadata(core_index), std::move(input),
|
||||
TPUExecute(*proto, *host_transfer_info, *hlo_metadata, std::move(input),
|
||||
rendezvous_key_base, rng_seed, node_context.get(),
|
||||
device_assignment.get(), context->cancellation_manager(),
|
||||
context, stream, transfer_stream_ptr.get(),
|
||||
tpu_program_group->tpu_program(core_index)));
|
||||
entry->get().get_tpu_program()));
|
||||
stream->ThenRecordEvent(definition_event.get());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<OutputBuffers> output_buffers,
|
||||
AllocateOutputTensors(
|
||||
context, output.ConsumeResult(), executable.output_tensor_shapes(),
|
||||
variable_update_map, node_context.get(), stream, device_ordinal,
|
||||
input_buffers.get(), definition_event));
|
||||
AllocateOutputTensors(context, output.ConsumeResult(),
|
||||
proto->output_tensor_shapes(), variable_update_map,
|
||||
node_context.get(), stream, device_ordinal,
|
||||
input_buffers.get(), definition_event));
|
||||
|
||||
// Transfer the outputs and save the snapshot to disk.
|
||||
if (hlo_snapshot) {
|
||||
|
||||
@ -21,9 +21,6 @@ limitations under the License.
|
||||
|
||||
typedef struct XLA_TpuProgram XLA_TpuProgram;
|
||||
|
||||
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
|
||||
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Creates a new TPU program.
|
||||
@ -67,15 +64,6 @@ TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
|
||||
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
|
||||
|
||||
// Check if TPU program has sharding.
|
||||
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
|
||||
const XLA_TpuProgram* tpu_program);
|
||||
|
||||
// Gets TPU program by sharding type. Return value is valid only when the
|
||||
// `status.status()` returns `OK`.
|
||||
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
|
||||
XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
|
||||
|
||||
struct TfTpu_TpuProgramApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
|
||||
@ -88,8 +76,6 @@ struct TfTpu_TpuProgramApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
|
||||
};
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
||||
#include "tensorflow/core/tpu/tpu_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||
@ -99,71 +98,55 @@ StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
|
||||
compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
|
||||
per_core_variable_indices, device_assignment);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TpuProgramGroup::Initialize(
|
||||
absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
|
||||
Status CreateTpuProgramGroup(
|
||||
absl::Span<XLA_TpuProgram* const> xla_tpu_programs,
|
||||
TpuProgramGroupInterface* tpu_program_group_interface) {
|
||||
CHECK_GT(xla_tpu_programs.size(), 0);
|
||||
set_tpu_programs(xla_tpu_programs);
|
||||
TpuProgramGroup* tpu_program_group =
|
||||
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
|
||||
CHECK_NE(tpu_program_group, nullptr);
|
||||
tpu_program_group->set_tpu_programs(xla_tpu_programs);
|
||||
|
||||
std::vector<bool> may_modify_variables_array(xla_tpu_programs.size(), false);
|
||||
std::vector<TPUExecutableInfoProto> executable_infos(xla_tpu_programs.size());
|
||||
std::vector<TPUHostTransferInfoProto> host_transfer_infos(
|
||||
xla_tpu_programs.size());
|
||||
std::vector<xla::HloProto> hlo_metadatas(xla_tpu_programs.size());
|
||||
for (size_t i = 0; i < xla_tpu_programs.size(); ++i) {
|
||||
const XLA_TpuProgram* xla_tpu_program = xla_tpu_programs[i];
|
||||
bool may_modify_variables;
|
||||
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(
|
||||
xla_tpu_program, &may_modify_variables);
|
||||
may_modify_variables_array[i] = may_modify_variables;
|
||||
// TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
|
||||
bool may_modify_variables;
|
||||
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_programs[0],
|
||||
&may_modify_variables);
|
||||
tpu_program_group->set_may_modify_variables(
|
||||
std::vector<bool>(1, may_modify_variables));
|
||||
|
||||
TpuSerializedProto serialized_executable_info;
|
||||
TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
|
||||
xla_tpu_program, &serialized_executable_info);
|
||||
TPUExecutableInfoProto executable_info =
|
||||
se_tpu::DeserializeProto<TPUExecutableInfoProto>(
|
||||
serialized_executable_info);
|
||||
executable_infos[i] = executable_info;
|
||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
|
||||
TpuSerializedProto serialized_executable_info;
|
||||
TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
|
||||
xla_tpu_programs[0], &serialized_executable_info);
|
||||
TPUExecutableInfoProto executable_info =
|
||||
se_tpu::DeserializeProto<TPUExecutableInfoProto>(
|
||||
serialized_executable_info);
|
||||
tpu_program_group->set_executable_info(executable_info);
|
||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
|
||||
|
||||
TPUHostTransferInfoProto host_transfer_info;
|
||||
TpuSerializedProto serialized_host_transfer_info;
|
||||
TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
|
||||
xla_tpu_program, &serialized_host_transfer_info);
|
||||
if (serialized_host_transfer_info.size > 0) {
|
||||
host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
|
||||
serialized_host_transfer_info);
|
||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
|
||||
}
|
||||
host_transfer_infos[i] = host_transfer_info;
|
||||
|
||||
TpuSerializedProto serialized_hlo_metadata;
|
||||
TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_program,
|
||||
&serialized_hlo_metadata);
|
||||
xla::HloProto hlo_metadata =
|
||||
se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
|
||||
hlo_metadatas[i] = hlo_metadata;
|
||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
|
||||
TPUHostTransferInfoProto host_transfer_info;
|
||||
TpuSerializedProto serialized_host_transfer_info;
|
||||
TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
|
||||
xla_tpu_programs[0], &serialized_host_transfer_info);
|
||||
if (serialized_host_transfer_info.size > 0) {
|
||||
host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
|
||||
serialized_host_transfer_info);
|
||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
|
||||
}
|
||||
tpu_program_group->set_host_transfer_info(host_transfer_info);
|
||||
|
||||
may_modify_variables_ = may_modify_variables_array;
|
||||
executable_infos_ = executable_infos;
|
||||
host_transfer_infos_ = host_transfer_infos;
|
||||
hlo_metadatas_ = hlo_metadatas;
|
||||
RefreshHloMetadatasPtrs();
|
||||
TpuSerializedProto serialized_hlo_metadata;
|
||||
TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_programs[0],
|
||||
&serialized_hlo_metadata);
|
||||
xla::HloProto hlo_metadata =
|
||||
se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
|
||||
tpu_program_group->set_hlo_metadata(hlo_metadata);
|
||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TpuProgramGroup::has_sharding_program() const {
|
||||
for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
|
||||
if (!TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the number of TPU programs currently held by this group.
size_t TpuProgramGroup::program_count() const {
  return tpu_programs_.size();
}
|
||||
} // namespace
|
||||
|
||||
int64_t TpuProgramGroup::program_size() const {
|
||||
int64_t total_size = 0;
|
||||
@ -218,6 +201,12 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
|
||||
TF_RET_CHECK(per_core_output_shapes.size() ==
|
||||
per_core_variable_indices.size());
|
||||
|
||||
// TODO(henrytan): add an interface to TpuProgramGroupInterface to set
|
||||
// may_modify_variables.
|
||||
TpuProgramGroup* tpu_program_group =
|
||||
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
|
||||
tpu_program_group->may_modify_variables_ = may_modify_variables;
|
||||
|
||||
// With shardable input/output pairs, XLA could generate separate
|
||||
// sharding/unsharding programs along with the main program. The
|
||||
// sharding/unsharding programs will be in nested entries of the AOT
|
||||
@ -232,20 +221,17 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
|
||||
TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
|
||||
xla_tpu_programs.size() == metadata.num_cores_per_replica());
|
||||
|
||||
// TODO(henrytan): add an interface to TpuProgramGroupInterface to set
|
||||
// may_modify_variables.
|
||||
TpuProgramGroup* tpu_program_group =
|
||||
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
|
||||
tpu_program_group->Initialize(xla_tpu_programs);
|
||||
tpu_program_group->may_modify_variables_ = may_modify_variables;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateTpuProgramGroup(xla_tpu_programs, tpu_program_group));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
|
||||
: may_modify_variables_(std::move(other.may_modify_variables_)),
|
||||
host_compute_metadata_(std::move(other.host_compute_metadata_)),
|
||||
tpu_programs_(std::move(other.tpu_programs_)),
|
||||
executable_infos_(std::move(other.executable_infos_)),
|
||||
host_transfer_infos_(std::move(other.host_transfer_infos_)),
|
||||
executable_info_(std::move(other.executable_info_)),
|
||||
host_transfer_info_(std::move(other.host_transfer_info_)),
|
||||
hlo_metadatas_(std::move(other.hlo_metadatas_)) {
|
||||
RefreshHloMetadatasPtrs();
|
||||
}
|
||||
@ -262,12 +248,6 @@ absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
|
||||
return hlo_metadatas_ptrs_;
|
||||
}
|
||||
|
||||
// Returns the HLO metadata pointer stored at position `index`.
// CHECK-fails unless 0 <= index < hlo_metadatas_ptrs_.size().
const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, hlo_metadatas_ptrs_.size());
  const xla::HloProto* const metadata = hlo_metadatas_ptrs_[index];
  return metadata;
}
|
||||
|
||||
void TpuProgramGroup::RefreshHloMetadatasPtrs() {
|
||||
hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
|
||||
for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
|
||||
@ -282,47 +262,6 @@ Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const std::vector<bool>& TpuProgramGroup::may_modify_variables() const {
|
||||
return may_modify_variables_;
|
||||
}
|
||||
|
||||
void TpuProgramGroup::set_may_modify_variables(
|
||||
const std::vector<bool>& may_modify_variables) {
|
||||
may_modify_variables_ = may_modify_variables;
|
||||
}
|
||||
|
||||
const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
|
||||
return tpu_programs_;
|
||||
}
|
||||
|
||||
// Returns the TPU program pointer stored at position `index`.
// CHECK-fails unless 0 <= index < tpu_programs_.size().
const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  const XLA_TpuProgram* const program = tpu_programs_[index];
  return program;
}
|
||||
|
||||
// Replaces the stored program pointers with the pointers in `tpu_programs`,
// preserving their order. Only the pointers are copied.
void TpuProgramGroup::set_tpu_programs(
    absl::Span<XLA_TpuProgram* const> tpu_programs) {
  tpu_programs_.assign(tpu_programs.begin(), tpu_programs.end());
}
|
||||
|
||||
// Returns the executable info stored at position `index`.
// CHECK-fails unless 0 <= index < executable_infos_.size().
const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
    int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, executable_infos_.size());
  const TPUExecutableInfoProto& info = executable_infos_[index];
  return info;
}
|
||||
|
||||
// Returns the host transfer info stored at position `index`.
// CHECK-fails unless 0 <= index < host_transfer_infos_.size().
const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
    int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, host_transfer_infos_.size());
  const TPUHostTransferInfoProto& info = host_transfer_infos_[index];
  return info;
}
|
||||
|
||||
/*static*/
|
||||
Status TpuProgramGroup::CompileAndBuild(
|
||||
const TpuCompilationRequestProto& compilation_request,
|
||||
@ -348,27 +287,15 @@ Status TpuProgramGroup::CompileAndBuild(
|
||||
TF_RET_CHECK(count == 1 ||
|
||||
count == compilation_request.metadata().num_cores_per_replica());
|
||||
|
||||
VLOG(1) << "Initialize TpuProgramGroup.";
|
||||
TpuProgramGroup* tpu_program_group =
|
||||
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
|
||||
tpu_program_group->Initialize(
|
||||
absl::MakeConstSpan(&xla_tpu_programs[0], count));
|
||||
VLOG(1) << "CreateTpuProgramGroup";
|
||||
Status serialize_status =
|
||||
CreateTpuProgramGroup(absl::MakeConstSpan(&xla_tpu_programs[0], count),
|
||||
tpu_program_group_interface);
|
||||
VLOG(1) << absl::StrCat("Run CreateTpuProgramGroup completed. StatusCode: ",
|
||||
serialize_status.code());
|
||||
TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
|
||||
return status.status();
|
||||
return serialize_status;
|
||||
}
|
||||
|
||||
// Collects, for every program in this group that has sharding (per
// TpuProgram_HasShardingFn), the program returned by
// TpuProgram_GetTpuProgramFn for `sharding_type`.
//
// Programs without sharding are skipped, so the returned vector may be
// shorter than `tpu_programs_`. CHECK-fails if a fetched program is null.
std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
    TpuProgramShardingType sharding_type) const {
  std::vector<XLA_TpuProgram*> tpu_programs;
  tpu_programs.reserve(tpu_programs_.size());
  for (XLA_TpuProgram* tpu_program : tpu_programs_) {
    if (!TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
      continue;
    }
    XLA_TpuProgram* sharding_program =
        TpuProgramApiFn()->TpuProgram_GetTpuProgramFn(tpu_program,
                                                      sharding_type);
    // Validate the pointer that was just fetched. Indexing the output vector
    // with the input position reads out of bounds as soon as an earlier
    // program has been skipped.
    CHECK_NE(sharding_program, nullptr);
    tpu_programs.push_back(sharding_program);
  }
  return tpu_programs;
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -102,16 +102,11 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
|
||||
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
|
||||
TpuProgramGroupInterface* tpu_program_group_interface);
|
||||
|
||||
// Initializes `TpuProgramGroup` object with `xla_tpu_programs`.
|
||||
void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);
|
||||
|
||||
TpuProgramGroup() = default;
|
||||
TpuProgramGroup(TpuProgramGroup&& other);
|
||||
TpuProgramGroup& operator=(TpuProgramGroup&&) = delete;
|
||||
|
||||
bool has_sharding_program() const override;
|
||||
|
||||
size_t program_count() const override;
|
||||
// Returns the number of TPU programs currently held by this group.
size_t program_count() const override {
  return tpu_programs_.size();
}
|
||||
|
||||
int64_t program_size() const override;
|
||||
|
||||
@ -122,29 +117,58 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
|
||||
Status LogCompilationStats(const TpuCompilationCacheKey& key,
|
||||
absl::Duration duration) override;
|
||||
|
||||
const std::vector<bool>& may_modify_variables() const override;
|
||||
void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
|
||||
const std::vector<bool>& may_modify_variables() const override {
|
||||
return may_modify_variables_;
|
||||
}
|
||||
void set_may_modify_variables(const std::vector<bool>& may_modify_variables) {
|
||||
may_modify_variables_ = may_modify_variables;
|
||||
}
|
||||
|
||||
const std::vector<XLA_TpuProgram*>& tpu_programs() const;
|
||||
std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;
|
||||
const XLA_TpuProgram* tpu_program(int index) const;
|
||||
void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs);
|
||||
// Returns a reference to the stored host compute metadata.
const tf2xla::HostComputeMetadata& host_compute_metadata() const {
  const tf2xla::HostComputeMetadata& metadata = host_compute_metadata_;
  return metadata;
}
|
||||
void set_host_compute_metadata(
|
||||
const tf2xla::HostComputeMetadata& host_compute_metadata) {
|
||||
host_compute_metadata_ = host_compute_metadata;
|
||||
}
|
||||
|
||||
const TPUExecutableInfoProto& executable_info(int index) const;
|
||||
const std::vector<XLA_TpuProgram*>& tpu_programs() const {
|
||||
return tpu_programs_;
|
||||
}
|
||||
// Replaces the stored program pointers with the pointers in `tpu_programs`,
// preserving their order. Only the pointers are copied.
void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs) {
  tpu_programs_.assign(tpu_programs.begin(), tpu_programs.end());
}
|
||||
|
||||
// Returns a reference to the stored executable info.
const TPUExecutableInfoProto& executable_info() const {
  const TPUExecutableInfoProto& info = executable_info_;
  return info;
}
|
||||
void set_executable_info(const TPUExecutableInfoProto& executable_info) {
|
||||
executable_info_ = executable_info;
|
||||
}
|
||||
|
||||
// Returns a reference to the stored host transfer info.
const TPUHostTransferInfoProto& host_transfer_info() const {
  const TPUHostTransferInfoProto& info = host_transfer_info_;
  return info;
}
|
||||
void set_host_transfer_info(
|
||||
const TPUHostTransferInfoProto& host_transfer_info) {
|
||||
host_transfer_info_ = host_transfer_info;
|
||||
}
|
||||
|
||||
const TPUHostTransferInfoProto& host_transfer_info(int index) const;
|
||||
void set_hlo_metadata(const xla::HloProto& hlo_metadata);
|
||||
const xla::HloProto* hlo_metadata(int index) const;
|
||||
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
|
||||
|
||||
private:
|
||||
void RefreshHloMetadatasPtrs();
|
||||
|
||||
std::vector<bool> may_modify_variables_;
|
||||
tf2xla::HostComputeMetadata host_compute_metadata_;
|
||||
|
||||
std::vector<XLA_TpuProgram*> tpu_programs_; // Not owned.
|
||||
std::vector<TPUExecutableInfoProto> executable_infos_;
|
||||
std::vector<TPUHostTransferInfoProto> host_transfer_infos_;
|
||||
TPUExecutableInfoProto executable_info_;
|
||||
TPUHostTransferInfoProto host_transfer_info_;
|
||||
|
||||
// To be consistent with the TpuProgramGroupInterface::hlo_metadatas()
|
||||
// signature, we store HloProto values in hlo_metadatas_ when
|
||||
|
||||
@ -20,8 +20,6 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -36,16 +34,13 @@ class TpuProgramGroupInterface {
|
||||
public:
|
||||
virtual ~TpuProgramGroupInterface() = default;
|
||||
|
||||
// Checks whether a sharding/unsharding program exists.
|
||||
virtual bool has_sharding_program() const = 0;
|
||||
|
||||
// Computes program count.
|
||||
virtual size_t program_count() const = 0;
|
||||
|
||||
// Computes total program size.
|
||||
virtual int64_t program_size() const = 0;
|
||||
|
||||
// Unloads and destroys safely TPU programs.
|
||||
// Unloads and destroys safely Tpu programs.
|
||||
virtual void UnloadAndDestroyPrograms() = 0;
|
||||
|
||||
// Logs program memory summary.
|
||||
|
||||
@ -64,8 +64,6 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
|
||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHostTransferInfo);
|
||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata);
|
||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
|
||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
|
||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user