Refactor the TpuCompilationCacheEntry interface to return TpuProgramGroupInterface and core_index, make CacheEntry less transparent, and move application-specific logic outside of the cache.

PiperOrigin-RevId: 324705343
Change-Id: I9dc421df069dbe7dc9bb57695f06e8b636fbc945
Russell Power 2020-08-03 16:16:10 -07:00 committed by TensorFlower Gardener
parent 9474df4a12
commit 3cf7683cfe
22 changed files with 523 additions and 509 deletions
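
The call-site effect of this refactor is visible in the TPUExecuteOp diff further below: instead of exposing its TpuProgramGroupInterface and core index and letting every caller down_cast and index into it, the entry now resolves the core index internally and hands back only what the execute path needs. A minimal before/after sketch, using only names that appear in this diff (error handling elided):

// Before: the entry was transparent; callers re-derived everything.
tpu::TpuCompilationCacheEntry entry = entry_ref->get();
const tpu::TpuProgramGroup* group =
    tensorflow::down_cast<const tpu::TpuProgramGroup*>(
        entry.tpu_program_group());
const TPUExecutableInfoProto& executable =
    group->executable_info(entry.core_index());

// After: application-specific lookups live behind the entry's accessors.
const TPUExecutableInfoProto* executable_info =
    entry_ref->get().get_executable_info();
const XLA_TpuProgram* tpu_program = entry_ref->get().get_tpu_program();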

View File

@ -92,8 +92,6 @@ tf_kernel_library(
deps = [
":tpu_compilation_cache_factory",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_local_lookup",
":tpu_compilation_cache_lookup",
":tpu_mesh_state_interface",
":tpu_op_consts",
"//tensorflow/c:tf_status",
@ -210,14 +208,30 @@ cc_library(
cc_library(
name = "tpu_compilation_cache_entry",
srcs = ["tpu_compilation_cache_entry.cc"],
hdrs = [
"tpu_compilation_cache_entry.h",
],
deps = [
":compiled_subgraph",
":tpu_compilation_cache_proto_cc",
":tpu_executable_info_proto_cc",
":tpu_program_group_interface",
":tpu_program_group",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core/lib/core:refcount",
"//tensorflow/core/platform:casts",
],
)
cc_library(
name = "tpu_compilation_cache_entry_impl",
srcs = [],
hdrs = ["tpu_compilation_cache_entry_impl.h"],
deps = [
":compiled_subgraph",
":tpu_compilation_cache_interface",
":tpu_executable_info_proto_cc",
],
)
@ -288,8 +302,6 @@ cc_library(
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core/lib/core:status",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
],
)
@ -329,7 +341,6 @@ cc_library(
hdrs = ["tpu_compilation_cache_interface.h"],
deps = [
":compiled_subgraph",
":tpu_compilation_cache_entry",
":tpu_compilation_cache_key",
":tpu_compilation_cache_proto_cc",
":tpu_compilation_metrics_hdrs",
@ -361,6 +372,7 @@ cc_library(
deps = [
":compiled_subgraph",
":tpu_compilation_cache_entry",
":tpu_compilation_cache_entry_impl",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_key",
":tpu_compilation_cache_proto_cc",
@ -370,7 +382,6 @@ cc_library(
":tpu_compile_op_support",
":tpu_mesh_state_interface",
":tpu_op_consts",
":tpu_program_c_api_hdrs",
":tpu_program_group",
":tpu_util",
":trace_util_hdrs",
@ -380,10 +391,10 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
@ -604,7 +615,6 @@ cc_library(
deps = [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_external",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_local_lookup",
":tpu_compilation_cache_lookup",
":tpu_executable_info_proto_cc",

View File

@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/platform/casts.h"
namespace tensorflow {
namespace tpu {
TpuCompilationCacheEntry::TpuCompilationCacheEntry(
const TpuProgramGroupInterface* tpu_program_group, int core_index)
: tpu_program_group_(
tensorflow::down_cast<const TpuProgramGroup*>(tpu_program_group)),
core_index_(core_index) {}
// Constructor for an empty entry.
TpuCompilationCacheEntry::TpuCompilationCacheEntry()
: tpu_program_group_(nullptr) {}
const TPUExecutableInfoProto* TpuCompilationCacheEntry::get_executable_info()
const {
return &(tpu_program_group_->executable_info());
}
const TPUHostTransferInfoProto*
TpuCompilationCacheEntry::get_host_transfer_info() const {
return &(tpu_program_group_->host_transfer_info());
}
const xla::HloProto* TpuCompilationCacheEntry::get_hlo_metadata() const {
return tpu_program_group_->hlo_metadatas()[core_index_];
}
// TODO(henrytan,jiawenhao): When should we expect more than one
// XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
const XLA_TpuProgram* TpuCompilationCacheEntry::get_tpu_program() const {
CHECK_EQ(tpu_program_group_->program_count(), 1);
return tpu_program_group_->tpu_programs()[core_index_];
}
} // namespace tpu
} // namespace tensorflow

View File

@ -18,32 +18,30 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
namespace tensorflow {
namespace tpu {
// Cache entry to hold a `TpuProgramGroupInterface` object that can be used to
// fetch a TPU program for a given TPU core index.
// A version of `CompilationCacheEntry` to access Tpu binary program
// `XLA_TpuProgram`.
class TpuCompilationCacheEntry {
public:
explicit TpuCompilationCacheEntry(
const TpuProgramGroupInterface* tpu_program_group, int core_index)
: tpu_program_group_(tpu_program_group), core_index_(core_index) {}
const TpuProgramGroupInterface* tpu_program_group, int core_index);
// Constructor for an empty entry.
TpuCompilationCacheEntry() : tpu_program_group_(nullptr), core_index_(-1) {}
const TpuProgramGroupInterface* tpu_program_group() const {
return tpu_program_group_;
}
int core_index() const { return core_index_; }
TpuCompilationCacheEntry();
const TPUExecutableInfoProto* get_executable_info() const;
const TPUHostTransferInfoProto* get_host_transfer_info() const;
const xla::HloProto* get_hlo_metadata() const;
// TODO(henrytan): maybe nicer to return C++ wrapper of `XLA_TpuProgram`
const XLA_TpuProgram* get_tpu_program() const;
private:
const TpuProgramGroupInterface* tpu_program_group_;
const TpuProgramGroup* tpu_program_group_;
int core_index_;
};
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,94 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
namespace tensorflow {
namespace tpu {
// Wrapper for a cache entry that holds a reference to the entry until the
// wrapper is deleted. This wrapper is the concrete type of
// CompilationCacheEntryRef returned by Lookup.
template <typename CacheEntryType>
class CompilationCacheEntryRefImpl
: public CompilationCacheEntryRef<CacheEntryType> {
public:
CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent,
CompiledSubgraph* entry, int index);
~CompilationCacheEntryRefImpl() override;
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override;
protected:
TpuCompilationCacheInterface* parent_; // Not owned.
// A reference to entry_ is acquired in the constructor and released via
// parent->DiscardEntryRefs in the destructor.
CompiledSubgraph* entry_;
// The index of the program in entry_ that is returned by the get method.
int index_;
};
template <typename CacheEntryType>
CompilationCacheEntryRefImpl<CacheEntryType>::CompilationCacheEntryRefImpl(
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
: parent_(parent), entry_(entry), index_(index) {
if (entry_ == nullptr) {
return;
}
if (entry_->main_entry == nullptr) {
entry_->Ref();
} else {
// This is a sharding/unsharding entry nested in a main entry. Only
// refcount the main entry.
entry_->main_entry->Ref();
}
}
template <typename CacheEntryType>
CompilationCacheEntryRefImpl<CacheEntryType>::~CompilationCacheEntryRefImpl() {
if (entry_ == nullptr) {
return;
}
if (entry_->main_entry == nullptr) {
parent_->DiscardEntryRefs({entry_});
} else {
parent_->DiscardEntryRefs({entry_->main_entry});
}
}
template <typename CacheEntryType>
Status CompilationCacheEntryRefImpl<CacheEntryType>::ToSubEntryRef(
CompilationCacheFetchTarget fetch_target) {
CompiledSubgraph* target = nullptr;
switch (fetch_target) {
case CompilationCacheFetchTarget::MAIN:
target = entry_;
break;
case CompilationCacheFetchTarget::SHARDING:
target = entry_->sharding_entry.get();
break;
case CompilationCacheFetchTarget::UNSHARDING:
target = entry_->unsharding_entry.get();
break;
default:
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
}
if (target == nullptr) {
// Cache entry does not have an unsharding subentry. Unref and replace
// with nullptr.
parent_->DiscardEntryRefs({entry_});
}
// Otherwise, since the refcount is always on the main entry, we don't
// need ref/unref.
entry_ = target;
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
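
To make the reference-counting contract above concrete, here is a hedged usage sketch; `ref` stands for any concrete subclass of this template, such as the EntryRefImpl introduced later in this diff. The single reference taken on the main entry in the constructor also covers nested sharding/unsharding entries, so narrowing with ToSubEntryRef needs no extra Ref/Unref:

// `ref` holds one reference on the main entry (taken in the constructor).
TF_RETURN_IF_ERROR(ref->ToSubEntryRef(CompilationCacheFetchTarget::SHARDING));
// If the subgraph has a sharding program, `ref` now points at the nested
// sharding entry while the refcount stays on the main entry. If it does
// not, the entry was unref'ed and replaced with nullptr, and get() will
// return an empty CacheEntryType.
// Destroying `ref` releases the reference via parent->DiscardEntryRefs.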

View File

@ -16,18 +16,15 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/trace_util.h"
@ -51,22 +48,23 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry,
entry->tpu_program_group =
absl::make_unique<TpuProgramGroup>(std::move(tpu_program_group));
entry->initialized = true;
if (entry->initialization_status.ok()) {
// Compute the entry's total size once all members are initialized.
entry->total_size = entry->ComputeTotalSize();
}
}
std::unique_ptr<CompiledSubgraph> CreateAndInitializeCompiledSubgraph(
CompiledSubgraph* main_entry) {
auto entry = absl::make_unique<CompiledSubgraph>();
entry->main_entry = main_entry;
entry->tpu_program_group = absl::make_unique<TpuProgramGroup>();
return entry;
}
} // namespace
TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl(
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
: CompilationCacheEntryRefImpl<TpuCompilationCacheEntry>(parent, entry,
index) {}
TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() {
if (entry_ == nullptr) {
// Create an empty entry if the entry is nullptr. This corresponds to
// non-existing sharding/unsharding entries.
return TpuCompilationCacheEntry();
}
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
}
CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
const string& key,
const std::function<Status(TpuProgramGroupInterface*)>& initialize_program,
@ -75,6 +73,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
main_entry->parent = this;
main_entry->subgraph_key = key;
main_entry->uid = get_uid();
// TODO(henrytan): implement TpuCompilationCacheKey.debug_string.
main_entry->cache_entry_debug_string = subgraph_key.prefix;
VLOG(1) << "Cache Initializing Entry Session Debug "
<< main_entry->cache_entry_debug_string;
@ -113,29 +112,17 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
std::pair<int64, CompiledSubgraph*>(main_entry->uid, main_entry));
CHECK(uid_inserted.second);
if (tpu_program_group.has_sharding_program()) {
main_entry->sharding_entry =
CreateAndInitializeCompiledSubgraph(main_entry);
TpuProgramGroup sharding_programs;
sharding_programs.Initialize(
tpu_program_group.tpu_programs(TpuProgramShardingType::kSharding));
PopulateEntry(key, main_entry->sharding_entry.get(),
std::move(sharding_programs));
main_entry->unsharding_entry =
CreateAndInitializeCompiledSubgraph(main_entry);
TpuProgramGroup unsharding_programs;
unsharding_programs.Initialize(
tpu_program_group.tpu_programs(TpuProgramShardingType::kUnsharding));
PopulateEntry(key, main_entry->unsharding_entry.get(),
std::move(unsharding_programs));
if (initialization_status.ok()) {
// Compute the entry's total size once all members are initialized.
main_entry->total_size = tpu_program_group.program_size();
}
// TODO(henrytan): handle sharding/unsharding.
PopulateEntry(key, main_entry, std::move(tpu_program_group));
for (int64 i = 0; i < main_entry->proto_key.size(); ++i) {
auto entry_inserted = entries_by_proto_key_.insert(
std::pair<std::string, std::pair<CompiledSubgraph*, int>>(
std::pair<string, std::pair<CompiledSubgraph*, int>>(
main_entry->proto_key[i], std::make_pair(main_entry, i)));
CHECK(entry_inserted.second);
}

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
@ -45,6 +46,17 @@ namespace tpu {
class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
public:
using Status = ::stream_executor::port::Status;
class EntryRefImpl
: public CompilationCacheEntryRefImpl<TpuCompilationCacheEntry> {
public:
EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry,
int index);
TpuCompilationCacheEntry get() override;
};
explicit TpuCompilationCacheExternal(int64 max_cache_size)
: TpuCompilationCacheInterface(max_cache_size) {}

View File

@ -38,77 +38,10 @@ void TpuCompilationCacheInterface::RefHolder::AddRef(CompiledSubgraph* entry) {
entries_.push_back(entry);
}
std::string TpuCompilationCacheInterface::RefHolder::DebugString() const {
string TpuCompilationCacheInterface::RefHolder::DebugString() const {
return "TpuCompilationCacheRefHolder";
}
CompilationCacheEntryRef::CompilationCacheEntryRef()
: parent_(nullptr), entry_(nullptr), index_(0) {}
CompilationCacheEntryRef::CompilationCacheEntryRef(
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
: parent_(parent), entry_(entry), index_(index) {
if (entry_ == nullptr) {
return;
}
if (entry_->main_entry == nullptr) {
entry_->Ref();
} else {
// This is a sharding/unsharding entry nested in a main entry. Only
// refcount the main entry.
entry_->main_entry->Ref();
}
}
CompilationCacheEntryRef::~CompilationCacheEntryRef() {
if (entry_ == nullptr) {
return;
}
if (entry_->main_entry == nullptr) {
parent_->DiscardEntryRefs({entry_});
} else {
parent_->DiscardEntryRefs({entry_->main_entry});
}
}
TpuCompilationCacheEntry CompilationCacheEntryRef::get() {
if (entry_ == nullptr) {
// Create an empty entry if the entry is nullptr. This corresponds to
// non-existing sharding/unsharding entries.
return TpuCompilationCacheEntry();
}
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
}
Status CompilationCacheEntryRef::ToSubEntryRef(
CompilationCacheFetchTarget fetch_target) {
CompiledSubgraph* target = nullptr;
switch (fetch_target) {
case CompilationCacheFetchTarget::MAIN:
target = entry_;
break;
case CompilationCacheFetchTarget::SHARDING:
target = entry_->sharding_entry.get();
break;
case CompilationCacheFetchTarget::UNSHARDING:
target = entry_->unsharding_entry.get();
break;
default:
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
}
if (target == nullptr) {
// Cache entry does not have an unsharding subentry. Unref and replace
// with nullptr.
parent_->DiscardEntryRefs({entry_});
}
// Otherwise, since the refcount is always on the main entry, we don't
// need ref/unref.
entry_ = target;
return Status::OK();
}
TpuCompilationCacheInterface::TpuCompilationCacheInterface(int64 max_cache_size)
: max_cache_size_(max_cache_size) {
CHECK_GE(max_cache_size_, 0);
@ -223,7 +156,7 @@ void TpuCompilationCacheInterface::UnloadAndDestroy(CompiledSubgraph* entry) {
entry->Unref();
}
size_t TpuCompilationCacheInterface::RemoveEntry(const std::string& key) {
size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) {
auto erased = cache_.erase(key);
TpuCompilationMetrics::SetCacheEntryCount(cache_.size());
@ -263,7 +196,7 @@ CompiledSubgraph* TpuCompilationCacheInterface::DiscardEntryRef(
}
erased = entries_by_uid_.erase(entry->uid);
CHECK_EQ(erased, 1);
for (const std::string& key : entry->proto_key) {
for (const string& key : entry->proto_key) {
erased = entries_by_proto_key_.erase(key);
CHECK_EQ(erased, 1);
}
@ -336,10 +269,10 @@ void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
}
}
void TpuCompilationCacheInterface::InsertEntry(const std::string& key,
void TpuCompilationCacheInterface::InsertEntry(const string& key,
CompiledSubgraph* entry) {
auto cache_inserted =
cache_.insert(std::pair<std::string, CompiledSubgraph*>(key, entry));
cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
CHECK(cache_inserted.second);
TpuCompilationMetrics::SetCacheEntryCount(cache_.size());
@ -362,8 +295,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
CompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<std::string>* proto_key,
std::vector<bool>* may_modify_variables,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
absl::Span<const xla::HloProto* const>* hlo_metadatas,
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
std::vector<CompiledSubgraph*> removed_entries;
@ -376,7 +308,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
return status;
}
std::string TpuCompilationCacheInterface::FindCacheKey(
string TpuCompilationCacheInterface::FindCacheKey(
const TpuCompilationCacheKey& subgraph_key) {
if (!subgraph_key.has_guaranteed_const) {
return subgraph_key.prefix;
@ -399,8 +331,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
CompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<std::string>* proto_key,
std::vector<bool>* may_modify_variables,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
std::vector<CompiledSubgraph*>* removed_entries,
absl::Span<const xla::HloProto* const>* hlo_metadatas,
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
@ -414,18 +345,17 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
// for the lifetime of the object, see InitializeEntry() call below.
absl::MutexLock lock(&mu_);
std::string cache_key = FindCacheKey(subgraph_key);
string cache_key = FindCacheKey(subgraph_key);
auto iter = cache_.find(cache_key);
bool is_new_key = iter == cache_.end();
const std::string session_name =
tpu::SessionNameFromMetadata(session_metadata);
const string session_name = tpu::SessionNameFromMetadata(session_metadata);
if (is_new_key) {
cache_key = subgraph_key.ToString();
TpuCompilationMetrics::IncrementCacheLookupCount(
/*is_cache_hit=*/false, session_name);
const std::string msg =
const string msg =
strings::StrCat("TPU host compilation cache miss: cache_key(",
cache_key, "), session_name(", session_name, ")");
TRACESTRING(msg);
@ -434,7 +364,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
// Check if caller has disabled compilation. Set using
// internal::ScopedTpuCompileDisabler.
if (!UtilApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
const std::string error_msg = strings::StrCat(
const string error_msg = strings::StrCat(
"[TpuCompilationDisabled]: Compilation cache miss, but compilation "
"disabled, session_name(",
session_name, ") Debug String: ", subgraph_key.debug_string);
@ -473,7 +403,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
} else {
TpuCompilationMetrics::IncrementCacheLookupCount(
/*is_cache_hit=*/true, session_name);
const std::string msg =
const string msg =
strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
"), session_name(", session_name, ")");
TRACESTRING(msg);
@ -536,8 +466,8 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
return entry->initialization_status;
}
Status TpuCompilationCacheInterface::GetKeysFromUid(
int64 uid, std::vector<std::string>* keys) {
Status TpuCompilationCacheInterface::GetKeysFromUid(int64 uid,
std::vector<string>* keys) {
keys->clear();
absl::MutexLock lock(&mu_);
@ -549,49 +479,5 @@ Status TpuCompilationCacheInterface::GetKeysFromUid(
return Status::OK();
}
Status TpuCompilationCacheInterface::Lookup(
int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry) {
entry->reset();
profiler::TraceMe proto_lookup_traceme(
"TPU compilation cache proto lookup by uid",
/*level=*/2);
absl::MutexLock lock(&mu_);
const auto iter = entries_by_uid_.find(uid);
if (iter == entries_by_uid_.end()) {
return errors::NotFound("No subgraph found for uid ", uid);
}
CompiledSubgraph* cache_entry = iter->second;
if (proto_index < 0 ||
proto_index >= cache_entry->tpu_program_group->program_count()) {
return errors::NotFound("No proto found for core index ", proto_index,
" in subgraph with uid ", uid);
}
*entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
proto_index);
return Status::OK();
}
Status TpuCompilationCacheInterface::Lookup(
const std::string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry) {
entry->reset();
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
/*level=*/2);
absl::MutexLock lock(&mu_);
const auto iter = entries_by_proto_key_.find(proto_key);
if (iter == entries_by_proto_key_.end()) {
return errors::NotFound("No proto found for key ", proto_key);
}
CompiledSubgraph* cache_entry = iter->second.first;
int proto_index = iter->second.second;
*entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
proto_index);
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow

View File

@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/trace_util.h"
@ -49,20 +48,18 @@ class CompilationRefHolder : public ResourceBase {
~CompilationRefHolder() override = default;
};
// Wrapper for a cache entry returned by all the TpuCompilationCacheInterface
// `Lookup` methods, and ensures the underlying proto is not garbage-collected
// until the client discards the ptr.
// Base class for a reference to a cached tpu program. A unique_ptr to a
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
// and ensures the underlying proto is not garbage-collected until the client
// discards the ptr.
template <typename CacheEntryType>
class CompilationCacheEntryRef {
public:
CompilationCacheEntryRef();
CompilationCacheEntryRef(TpuCompilationCacheInterface* parent,
CompiledSubgraph* entry, int index);
virtual ~CompilationCacheEntryRef() = default;
virtual ~CompilationCacheEntryRef();
// Returns a TpuCompilationCacheEntry that should not be used beyond the
// lifetime of the CompilationCacheEntryRef.
virtual TpuCompilationCacheEntry get();
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
// of the tpu::CompilationCacheEntryRef.
virtual CacheEntryType get() = 0;
// Mutates this ref to point to the entry's subentry (for
// sharding/unsharding) or main entry (unchanged) as specified by
@ -72,15 +69,7 @@ class CompilationCacheEntryRef {
//
// If the requested subentry does not exist, the ref will point to a nullptr
// entry, and the original entry will be unref'ed.
virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
protected:
TpuCompilationCacheInterface* parent_; // Not owned.
// A reference to entry_ is acquired in the constructor and released via
// parent->DiscardEntryRefs in the destructor.
CompiledSubgraph* entry_;
// The index of the program in entry_ that is returned by the get method.
int index_;
virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0;
};
class TpuCompilationCacheInterface : public ResourceBase {
@ -108,8 +97,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
CompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<std::string>* proto_key,
std::vector<bool>* may_modify_variables,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
absl::Span<const xla::HloProto* const>* hlo_metadatas,
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
@ -136,18 +124,19 @@ class TpuCompilationCacheInterface : public ResourceBase {
// Looks up an executable corresponding to the model-parallel core index of
// the subgraph represented by key. On success a pointer to an EntryRef
// holding the program is returned in entry.
Status Lookup(const std::string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry);
template <typename CacheEntryRef, typename CacheEntryRefImpl>
Status Lookup(const string& proto_key, std::unique_ptr<CacheEntryRef>* entry);
// Looks up an executable corresponding to the model-parallel core index of
// the subgraph represented by uid. On success a pointer to an EntryRef
// holding the program is returned in entry.
template <typename CacheEntryRef, typename CacheEntryRefImpl>
Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry);
std::unique_ptr<CacheEntryRef>* entry);
// Looks up the subgraph represented by uid, and returns the vector of keys,
// one per core, corresponding to that subgraph.
Status GetKeysFromUid(int64 uid, std::vector<std::string>* keys);
Status GetKeysFromUid(int64 uid, std::vector<string>* keys);
// Makes a reference holder for this cache, that can be stored in the per-step
// resource manager and will ensure that compiled entries persist until the
@ -181,7 +170,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
// parent_->DiscardEntryRefs.
void AddRef(CompiledSubgraph* entry);
std::string DebugString() const override;
string DebugString() const override;
private:
TpuCompilationCacheInterface* parent_; // Not owned.
@ -196,8 +185,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
CompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<std::string>* proto_key,
std::vector<bool>* may_modify_variables,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
std::vector<CompiledSubgraph*>* removed_entries,
absl::Span<const xla::HloProto* const>* hlo_metadatas,
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
@ -242,14 +230,14 @@ class TpuCompilationCacheInterface : public ResourceBase {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Removes the entry with given key from cache.
size_t RemoveEntry(const std::string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Inserts the given key and entry to cache.
void InsertEntry(const std::string& key, CompiledSubgraph* entry)
void InsertEntry(const string& key, CompiledSubgraph* entry)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the cache key matching given subgraph_key.
std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Creates a new entry by running initialize_programs and places it in the
@ -259,7 +247,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
//
// **InitializeEntry releases mu_ during the call to initialize_programs.**
virtual CompiledSubgraph* InitializeEntry(
const std::string& key,
const string& key,
const std::function<Status(TpuProgramGroupInterface*)>&
initialize_programs,
const TpuCompilationCacheKey& subgraph_key)
@ -288,16 +276,13 @@ class TpuCompilationCacheInterface : public ResourceBase {
// cache_ key matching a given subgraph key. When doing a lookup, check
// session_key_map_ first to avoid unnecessary fingerprint computation.
// Map from key prefix + session_handle to a cache_ key.
absl::node_hash_map<std::string, std::string> session_key_map_
ABSL_GUARDED_BY(mu_);
absl::node_hash_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
// Map from key prefix + fingerprint to a cache_ key.
absl::node_hash_map<std::string, std::string> fingerprint_key_map_
ABSL_GUARDED_BY(mu_);
absl::node_hash_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
// All the subgraph entries that can be looked up in the cache. An entry is
// marked for eviction iff it is present in cache_ and not in
// entries_by_last_use_.
std::unordered_map<std::string, CompiledSubgraph*> cache_
ABSL_GUARDED_BY(mu_);
std::unordered_map<string, CompiledSubgraph*> cache_ ABSL_GUARDED_BY(mu_);
// All the subgraph entries that can be looked up in the cache, indexed by
// uid.
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
@ -305,7 +290,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
// All the protos that can be looked up in the cache, indexed by proto
// key. The value of the map is a subgraph and the index of the proto compiled
// for that subgraph.
std::unordered_map<std::string, std::pair<CompiledSubgraph*, int>>
std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
// Map from last_use to entry, used to mark entries for eviction in LRU
// order. If an entry's last_use counter is not present as a key in
@ -319,6 +304,50 @@ class TpuCompilationCacheInterface : public ResourceBase {
TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
delete;
};
template <typename CacheEntryRef, typename CacheEntryRefImpl>
Status TpuCompilationCacheInterface::Lookup(
int64 uid, int proto_index, std::unique_ptr<CacheEntryRef>* entry) {
entry->reset();
profiler::TraceMe proto_lookup_traceme(
"TPU compilation cache proto lookup by uid",
/*level=*/2);
absl::MutexLock lock(&mu_);
const auto iter = entries_by_uid_.find(uid);
if (iter == entries_by_uid_.end()) {
return errors::NotFound("No subgraph found for uid ", uid);
}
CompiledSubgraph* cache_entry = iter->second;
if (proto_index < 0 ||
proto_index >= cache_entry->tpu_program_group->program_count()) {
return errors::NotFound("No proto found for core index ", proto_index,
" in subgraph with uid ", uid);
}
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
return Status::OK();
}
template <typename CacheEntryRef, typename CacheEntryRefImpl>
Status TpuCompilationCacheInterface::Lookup(
const string& proto_key, std::unique_ptr<CacheEntryRef>* entry) {
entry->reset();
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
/*level=*/2);
absl::MutexLock lock(&mu_);
const auto iter = entries_by_proto_key_.find(proto_key);
if (iter == entries_by_proto_key_.end()) {
return errors::NotFound("No proto found for key ", proto_key);
}
CompiledSubgraph* cache_entry = iter->second.first;
int proto_index = iter->second.second;
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow
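
Because Lookup is now templated on both the abstract ref type handed to the caller and the concrete implementation the cache constructs, each lookup site names the pair explicitly. A sketch of the instantiation that TpuCompilationCacheLocalLookup performs (see its .cc diff below; `cache` is a TpuCompilationCacheInterface*):

using TpuCompilationCacheEntryRef =
    tpu::CompilationCacheEntryRef<tpu::TpuCompilationCacheEntry>;
using EntryRefImpl = tpu::TpuCompilationCacheExternal::EntryRefImpl;

std::unique_ptr<TpuCompilationCacheEntryRef> entry;
// First parameter: the abstract ref returned to the caller.
// Second parameter: the concrete type constructed from
// (parent, cache_entry, proto_index) inside Lookup.
Status s = cache->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
    proto_key, &entry);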

View File

@ -16,50 +16,70 @@ limitations under the License.
namespace tensorflow {
namespace tpu {
namespace {
class CompilationCacheFetchTargetUtility {
public:
CompilationCacheFetchTargetUtility()
: names_({"Invalid", "Main", "Sharding", "Unsharding"}) {}
std::string name(CompilationCacheFetchTarget target) const {
return names_[static_cast<int>(target)];
}
private:
const std::vector<std::string> names_;
};
std::string GetName(CompilationCacheFetchTarget target) {
static const auto* util = new CompilationCacheFetchTargetUtility();
return util->name(target);
}
} // namespace
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
TpuCompilationCacheInterface* cache)
: cache_(cache) {
cache_->Ref();
}
: cache_(cache) {}
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
cache_->Unref();
}
Status TpuCompilationCacheLocalLookup::Lookup(
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
const string& proto_key,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) {
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
/*level=*/2);
Status s = cache_->Lookup(proto_key, entry);
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
proto_key, entry);
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
<< s;
if (!s.ok()) {
return s;
}
s = (*entry)->ToSubEntryRef(fetch_target);
VLOG(1) << "Fetched subentry: "
<< CompilationCacheFetchTarget_Name(fetch_target) << " with status "
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
<< s;
return s;
}
Status TpuCompilationCacheLocalLookup::Lookup(
int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) {
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
/*level=*/2);
Status s = cache_->Lookup(uid, proto_index, entry);
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
uid, proto_index, entry);
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
<< " in local subgraph cache status " << s;
if (!s.ok()) {
return s;
}
s = (*entry)->ToSubEntryRef(fetch_target);
VLOG(1) << "Fetched subentry: "
<< CompilationCacheFetchTarget_Name(fetch_target) << " with status "
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
<< s;
return s;
}
@ -67,5 +87,6 @@ Status TpuCompilationCacheLocalLookup::Lookup(
string TpuCompilationCacheLocalLookup::DebugString() const {
return "TpuCompilationCacheLocalLookup";
}
} // namespace tpu
} // namespace tensorflow

View File

@ -28,17 +28,24 @@ namespace tpu {
// Class for looking up TPU programs when the execute and compile Op are in the
// same address space. The proto is simply looked up in the compilation cache,
// without any serialization taking place.
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
class TpuCompilationCacheLocalLookup
: public TpuCompilationCacheLookup<
CompilationCacheEntryRef<TpuCompilationCacheEntry>> {
public:
using TpuCompilationCacheEntryRef =
::tensorflow::tpu::CompilationCacheEntryRef<TpuCompilationCacheEntry>;
using EntryRefImpl =
::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl;
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
~TpuCompilationCacheLocalLookup() override;
Status Lookup(const string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) override;
Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry,
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) override;
string DebugString() const override;

View File

@ -23,11 +23,10 @@ limitations under the License.
namespace tensorflow {
namespace tpu {
// TODO(b/162241759): consider merging TpuCompilationCacheLookup and
// TpuCompilationCacheInterface.
// Base class allowing Execute Ops to look up TPU programs. Different subclasses
// are used when the execute Op is in the same address space as the compile Op,
// and when they need to communicate over RPC.
template <typename TpuCompilationCacheEntryRefType>
class TpuCompilationCacheLookup : public ResourceBase {
public:
~TpuCompilationCacheLookup() override = default;
@ -44,11 +43,12 @@ class TpuCompilationCacheLookup : public ResourceBase {
// fetch_target requests one of them, then after this call
// (*entry)->get().get_executable() will return nullptr.
virtual Status Lookup(const string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry,
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry,
CompilationCacheFetchTarget fetch_target) = 0;
virtual Status Lookup(const string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry) {
virtual Status Lookup(
const string& proto_key,
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry) {
return Lookup(proto_key, std::move(entry),
CompilationCacheFetchTarget::MAIN);
}
@ -58,15 +58,17 @@ class TpuCompilationCacheLookup : public ResourceBase {
// returned in program. The wrapper is guaranteed to be valid only during the
// execution of the Op requesting the proto.
virtual Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry,
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry,
CompilationCacheFetchTarget fetch_target) = 0;
virtual Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry) {
virtual Status Lookup(
int64 uid, int proto_index,
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry) {
return Lookup(uid, proto_index, std::move(entry),
CompilationCacheFetchTarget::MAIN);
}
};
} // namespace tpu
} // namespace tensorflow

View File

@ -413,6 +413,46 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
return Status::OK();
}
/* static */
Status TpuCompileOpKernelCommon::ComputeArgumentShapes(
const tpu::TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& dynamic_shapes,
std::vector<TensorShape>* arg_shapes) {
arg_shapes->resize(metadata.args_size());
int dynamic_shape_pos = 0;
for (int i = 0; i < metadata.args_size(); ++i) {
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
// The XLA compiler determines the shape of each constant by inspecting the
// value of its corresponding host-memory tensor. As a result, we don't need
// to give the compiler graph-inferred shapes for constant arguments.
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
continue;
}
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
PartialTensorShape static_shape(arg.shape());
TensorShape& shape = (*arg_shapes)[i];
if (static_shape.IsFullyDefined()) {
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
} else {
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
<< "Too few dynamic shapes";
shape = dynamic_shapes[dynamic_shape_pos++];
if (!static_shape.IsCompatibleWith(shape)) {
return errors::InvalidArgument(
"Mismatch between static and dynamic shape for argument. Static "
"shape: ",
static_shape.DebugString(),
"; dynamic shape: ", shape.DebugString());
}
}
}
// Checks we consumed all of the dynamic shapes.
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
<< "Too many dynamic shapes";
return Status::OK();
}
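
A worked example of the merging logic above, with hypothetical values (not taken from this change): metadata with three args — a GUARANTEED_CONSTANT, a fully defined [8,128], and a partially defined [?,128] — and dynamic_shapes = {[32,128]}:

// arg 0: GUARANTEED_CONSTANT -> skipped; the compiler infers its shape
//        from the host-memory tensor, so (*arg_shapes)[0] stays default.
// arg 1: [8,128] is fully defined -> copied into (*arg_shapes)[1] as-is.
// arg 2: [?,128] is partial -> consumes dynamic_shapes[0] = [32,128];
//        [?,128].IsCompatibleWith([32,128]) holds, so it becomes [32,128].
// Exit: dynamic_shape_pos == dynamic_shapes.size() == 1 -> Status::OK().
// A leftover dynamic shape fails "Too many dynamic shapes"; a second
// partially defined arg would fail "Too few dynamic shapes".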
// Function arguments and return values lose their device assignments, so we
// must recreate them.
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(

View File

@ -99,6 +99,15 @@ class TpuCompileOpKernelCommon {
const std::vector<TensorShape>& arg_shapes,
TpuProgramGroupInterface* tpu_program_group) = 0;
// Computes shapes for each argument. Uses both the static shape from the
// metadata, and the dynamic shapes where the static shape is not
// defined. There must be one dynamic_shape for each argument with a
// partially defined shape, in index order.
static Status ComputeArgumentShapes(
const tpu::TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& dynamic_shapes,
std::vector<TensorShape>* arg_shapes);
// Performs shape inference on `computation`, filling shape_info with operator
// shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
static Status RunShapeInferenceOnComputation(

View File

@ -540,43 +540,5 @@ Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
}
return Status::OK();
}
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& dynamic_shapes,
std::vector<TensorShape>* arg_shapes) {
arg_shapes->resize(metadata.args_size());
int dynamic_shape_pos = 0;
for (int i = 0; i < metadata.args_size(); ++i) {
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
// The XLA compiler determines the shape of each constant by inspecting the
// value of its corresponding host-memory tensor. As a result, we don't need
// to give the compiler graph-inferred shapes for constant arguments.
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
continue;
}
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
PartialTensorShape static_shape(arg.shape());
TensorShape& shape = (*arg_shapes)[i];
if (static_shape.IsFullyDefined()) {
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
} else {
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
<< "Too few dynamic shapes";
shape = dynamic_shapes[dynamic_shape_pos++];
if (!static_shape.IsCompatibleWith(shape)) {
return errors::InvalidArgument(
"Mismatch between static and dynamic shape for argument. Static "
"shape: ",
static_shape.DebugString(),
"; dynamic shape: ", shape.DebugString());
}
}
}
// Checks we consumed all of the dynamic shapes.
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
<< "Too many dynamic shapes";
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow

View File

@ -159,14 +159,6 @@ se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
TPUCompileMetadataProto* metadata,
NameAttrList* function_name,
std::string* mlir_module);
// Computes shapes for each argument. Uses both the static shape from the
// metadata, and the dynamic shapes where the static shape is not
// defined. There must be one dynamic_shape for each argument with a
// partially defined shape, in index order.
Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& dynamic_shapes,
std::vector<TensorShape>* arg_shapes);
} // namespace tpu
} // namespace tensorflow

View File

@ -25,8 +25,6 @@ limitations under the License.
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_api.h"
@ -255,10 +253,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
mesh_state_interface));
}
VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
rmgr, tpu::kCompiledProtoCacheResourceName));
if (enable_whole_mesh_compilations_) {
// If this is a whole mesh compilation mode, create the compilation cache,
// if missing.
@ -282,13 +276,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
if (local_compilation_cache != nullptr) {
local_compilation_cache->Unref();
tpu::TpuCompilationCacheLookup* proto_lookup;
proto_lookup =
new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
OP_REQUIRES_OK(
ctx, rmgr->Create(rmgr->default_container(),
tpu::kCompiledProtoCacheResourceName, proto_lookup));
}
Tensor* ctx_output;

View File

@ -40,12 +40,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
@ -58,10 +56,14 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
namespace tensorflow {
namespace {
using ::tensorflow::tpu::CompilationCacheEntryRef;
using ::tensorflow::tpu::TpuCompilationCacheLookup;
using ::tensorflow::tpu::TpuNodeContext;
using CompilationCacheEntryRef = ::tensorflow::tpu::CompilationCacheEntryRef<
::tensorflow::tpu::TpuCompilationCacheEntry>;
using TpuCompilationCacheLookup =
::tensorflow::tpu::TpuCompilationCacheLookup<CompilationCacheEntryRef>;
// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
@ -639,35 +641,28 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);
string rendezvous_key_base;
std::unique_ptr<CompilationCacheEntryRef> entry_ref;
std::unique_ptr<CompilationCacheEntryRef> entry;
TF_RETURN_IF_ERROR(
GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));
GetComputationCacheEntry(context, &rendezvous_key_base, &entry));
// Shapes of the inputs and outputs, in xla::Shape form.
tpu::TpuCompilationCacheEntry entry = entry_ref->get();
const tpu::TpuProgramGroup* tpu_program_group =
tensorflow::down_cast<const tpu::TpuProgramGroup*>(
entry.tpu_program_group());
CHECK_NE(tpu_program_group, nullptr);
const int core_index = entry.core_index();
const TPUExecutableInfoProto& executable =
tpu_program_group->executable_info(core_index);
const TPUExecutableInfoProto* proto = entry->get().get_executable_info();
xla::Backend* const backend = node_context->backend();
xla::TransferManager* const transfer_manager = backend->transfer_manager();
TF_RET_CHECK(context->op_device_context());
se::Stream* stream = context->op_device_context()->stream();
TF_RET_CHECK(executable.input_shapes_size() == 1);
TF_RET_CHECK(proto->input_shapes_size() == 1);
xla::Shape host_shape(executable.input_shapes(0));
xla::Shape host_shape(proto->input_shapes(0));
TF_ASSIGN_OR_RETURN(
auto variable_update_map,
BuildVariableUpdateMap(executable.variable_indices(),
BuildVariableUpdateMap(proto->variable_indices(),
fused_device_var_reads_in_computation_inputs_,
fused_device_var_updates_in_computation_outputs_,
executable.output_tensor_shapes().size()));
proto->output_tensor_shapes().size()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<InputBuffers> input_buffers,
BuildComputationInputs(context, host_shape, variable_update_map, backend,
@ -702,9 +697,8 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
// Snapshot the inputs, if a snapshot was requested.
std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
if (executable.has_session_module()) {
hlo_snapshot =
std::make_shared<xla::HloSnapshot>(executable.session_module());
if (proto->has_session_module()) {
hlo_snapshot = std::make_shared<xla::HloSnapshot>(proto->session_module());
auto literal =
std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
transfer_manager->TransferLiteralFromDevice(
@ -729,9 +723,9 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
const uint32 rng_seed = GetXLARandomSeed();
std::unique_ptr<xla::DeviceAssignment> device_assignment;
if (executable.has_device_assignment()) {
if (proto->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
executable.device_assignment()));
proto->device_assignment()));
}
VLOG(4) << "Input buffers after alias resolution: "
@ -749,24 +743,24 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
// we free a memory and reassign it to other users while a program is running,
// all subsequent writes to the program that could possibly clobber the memory
// will depend on the program to finish.
const TPUHostTransferInfoProto& host_transfer_info =
tpu_program_group->host_transfer_info(core_index);
const TPUHostTransferInfoProto* host_transfer_info =
entry->get().get_host_transfer_info();
const xla::HloProto* hlo_metadata = entry->get().get_hlo_metadata();
TF_ASSIGN_OR_RETURN(
xla::ExecutionOutput output,
TPUExecute(executable, host_transfer_info,
*tpu_program_group->hlo_metadata(core_index), std::move(input),
TPUExecute(*proto, *host_transfer_info, *hlo_metadata, std::move(input),
rendezvous_key_base, rng_seed, node_context.get(),
device_assignment.get(), context->cancellation_manager(),
context, stream, transfer_stream_ptr.get(),
tpu_program_group->tpu_program(core_index)));
entry->get().get_tpu_program()));
stream->ThenRecordEvent(definition_event.get());
TF_ASSIGN_OR_RETURN(
std::unique_ptr<OutputBuffers> output_buffers,
AllocateOutputTensors(
context, output.ConsumeResult(), executable.output_tensor_shapes(),
variable_update_map, node_context.get(), stream, device_ordinal,
input_buffers.get(), definition_event));
AllocateOutputTensors(context, output.ConsumeResult(),
proto->output_tensor_shapes(), variable_update_map,
node_context.get(), stream, device_ordinal,
input_buffers.get(), definition_event));
// Transfer the outputs and save the snapshot to disk.
if (hlo_snapshot) {

View File

@ -21,9 +21,6 @@ limitations under the License.
typedef struct XLA_TpuProgram XLA_TpuProgram;
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
extern "C" {
// Creates a new TPU program.
@ -67,15 +64,6 @@ TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
// Check if TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
const XLA_TpuProgram* tpu_program);
// Gets TPU program by sharding type. Return value is valid only when the
// `status.status()` returns `OK`.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
struct TfTpu_TpuProgramApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
@ -88,8 +76,6 @@ struct TfTpu_TpuProgramApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
};
} // extern "C"

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
@ -99,71 +98,55 @@ StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
per_core_variable_indices, device_assignment);
}
} // namespace
void TpuProgramGroup::Initialize(
absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
Status CreateTpuProgramGroup(
absl::Span<XLA_TpuProgram* const> xla_tpu_programs,
TpuProgramGroupInterface* tpu_program_group_interface) {
CHECK_GT(xla_tpu_programs.size(), 0);
set_tpu_programs(xla_tpu_programs);
TpuProgramGroup* tpu_program_group =
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
CHECK_NE(tpu_program_group, nullptr);
tpu_program_group->set_tpu_programs(xla_tpu_programs);
std::vector<bool> may_modify_variables_array(xla_tpu_programs.size(), false);
std::vector<TPUExecutableInfoProto> executable_infos(xla_tpu_programs.size());
std::vector<TPUHostTransferInfoProto> host_transfer_infos(
xla_tpu_programs.size());
std::vector<xla::HloProto> hlo_metadatas(xla_tpu_programs.size());
for (size_t i = 0; i < xla_tpu_programs.size(); ++i) {
const XLA_TpuProgram* xla_tpu_program = xla_tpu_programs[i];
bool may_modify_variables;
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(
xla_tpu_program, &may_modify_variables);
may_modify_variables_array[i] = may_modify_variables;
// TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
bool may_modify_variables;
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_programs[0],
&may_modify_variables);
tpu_program_group->set_may_modify_variables(
std::vector<bool>(1, may_modify_variables));
TpuSerializedProto serialized_executable_info;
TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
xla_tpu_program, &serialized_executable_info);
TPUExecutableInfoProto executable_info =
se_tpu::DeserializeProto<TPUExecutableInfoProto>(
serialized_executable_info);
executable_infos[i] = executable_info;
StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
TpuSerializedProto serialized_executable_info;
TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
xla_tpu_programs[0], &serialized_executable_info);
TPUExecutableInfoProto executable_info =
se_tpu::DeserializeProto<TPUExecutableInfoProto>(
serialized_executable_info);
tpu_program_group->set_executable_info(executable_info);
StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
TPUHostTransferInfoProto host_transfer_info;
TpuSerializedProto serialized_host_transfer_info;
TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
xla_tpu_program, &serialized_host_transfer_info);
if (serialized_host_transfer_info.size > 0) {
host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
serialized_host_transfer_info);
StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
}
host_transfer_infos[i] = host_transfer_info;
TpuSerializedProto serialized_hlo_metadata;
TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_program,
&serialized_hlo_metadata);
xla::HloProto hlo_metadata =
se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
hlo_metadatas[i] = hlo_metadata;
StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
TPUHostTransferInfoProto host_transfer_info;
TpuSerializedProto serialized_host_transfer_info;
TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
xla_tpu_programs[0], &serialized_host_transfer_info);
if (serialized_host_transfer_info.size > 0) {
host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
serialized_host_transfer_info);
StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
}
tpu_program_group->set_host_transfer_info(host_transfer_info);
may_modify_variables_ = may_modify_variables_array;
executable_infos_ = executable_infos;
host_transfer_infos_ = host_transfer_infos;
hlo_metadatas_ = hlo_metadatas;
RefreshHloMetadatasPtrs();
TpuSerializedProto serialized_hlo_metadata;
TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_programs[0],
&serialized_hlo_metadata);
xla::HloProto hlo_metadata =
se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
tpu_program_group->set_hlo_metadata(hlo_metadata);
StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
return Status::OK();
}
} // namespace
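The helper above applies the same serialize/deserialize/free pattern three times (executable info, host transfer info, HLO metadata); the host-transfer case additionally checks serialized.size before deserializing, since the transfer info may be empty. Reduced to its skeleton (a sketch only; "program" stands in for xla_tpu_programs[0]):

// Skeleton of the proto round trip used in CreateTpuProgramGroup above.
TpuSerializedProto serialized;  // buffer is allocated by the C API side
TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(program, &serialized);
TPUExecutableInfoProto proto =
    se_tpu::DeserializeProto<TPUExecutableInfoProto>(serialized);
StreamExecutor_Tpu_FreeSerializedProto(&serialized);  // always release the buffer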
int64_t TpuProgramGroup::program_size() const {
int64_t total_size = 0;
@ -218,6 +201,12 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
TF_RET_CHECK(per_core_output_shapes.size() ==
per_core_variable_indices.size());
// TODO(henrytan): add an interface to TpuProgramGroupInterface to set
// may_modify_variables.
TpuProgramGroup* tpu_program_group =
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
tpu_program_group->may_modify_variables_ = may_modify_variables;
// With shardable input/output pairs, XLA could generate separate
// sharding/unsharding programs along with the main program. The
// sharding/unsharding programs will be in nested entries of the AOT
@ -232,20 +221,17 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
xla_tpu_programs.size() == metadata.num_cores_per_replica());
TF_RETURN_IF_ERROR(
CreateTpuProgramGroup(xla_tpu_programs, tpu_program_group));
return Status::OK();
}
TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
: may_modify_variables_(std::move(other.may_modify_variables_)),
host_compute_metadata_(std::move(other.host_compute_metadata_)),
tpu_programs_(std::move(other.tpu_programs_)),
executable_info_(std::move(other.executable_info_)),
host_transfer_info_(std::move(other.host_transfer_info_)),
hlo_metadatas_(std::move(other.hlo_metadatas_)) {
RefreshHloMetadatasPtrs();
}
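Move construction is the only transfer path here: the header shown further down deletes move assignment (TpuProgramGroup& operator=(TpuProgramGroup&&) = delete;), so a populated group can be moved into a fresh object but never assigned over an existing one. A minimal illustration (sketch only):

TpuProgramGroup source;
TpuProgramGroup destination(std::move(source));  // OK: the move constructor above.
// destination = std::move(source);              // ill-formed: operator= is deleted.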
@ -262,12 +248,6 @@ absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
return hlo_metadatas_ptrs_;
}
void TpuProgramGroup::RefreshHloMetadatasPtrs() {
hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
@ -282,47 +262,6 @@ Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
return Status::OK();
}
/*static*/
Status TpuProgramGroup::CompileAndBuild(
const TpuCompilationRequestProto& compilation_request,
@ -348,27 +287,15 @@ Status TpuProgramGroup::CompileAndBuild(
TF_RET_CHECK(count == 1 ||
count == compilation_request.metadata().num_cores_per_replica());
VLOG(1) << "Initialize TpuProgramGroup.";
TpuProgramGroup* tpu_program_group =
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
tpu_program_group->Initialize(
absl::MakeConstSpan(&xla_tpu_programs[0], count));
VLOG(1) << "CreateTpuProgramGroup";
Status serialize_status =
CreateTpuProgramGroup(absl::MakeConstSpan(&xla_tpu_programs[0], count),
tpu_program_group_interface);
VLOG(1) << absl::StrCat("Run CreateTpuProgramGroup completed. StatusCode: ",
serialize_status.code());
TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
return serialize_status;
}
} // namespace tpu
} // namespace tensorflow
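From a caller's point of view the file reduces to the static CompileAndBuild entry point. A hedged usage sketch; the parameters between the request proto and the output group are not visible in this hunk, so mesh_state below is an assumption:

// Hedged sketch of driving CompileAndBuild; middle argument(s) are assumed.
TpuProgramGroup tpu_program_group;
TF_RETURN_IF_ERROR(TpuProgramGroup::CompileAndBuild(
    compilation_request, mesh_state, &tpu_program_group));
// On success the group holds the XLA_TpuProgram handles and exposes program
// 0's metadata through the accessors declared in tpu_program_group.h.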

View File

@ -102,16 +102,11 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
TpuProgramGroupInterface* tpu_program_group_interface);
TpuProgramGroup() = default;
TpuProgramGroup(TpuProgramGroup&& other);
TpuProgramGroup& operator=(TpuProgramGroup&&) = delete;
size_t program_count() const override { return tpu_programs_.size(); }
int64_t program_size() const override;
@ -122,29 +117,58 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
Status LogCompilationStats(const TpuCompilationCacheKey& key,
absl::Duration duration) override;
const std::vector<bool>& may_modify_variables() const override {
return may_modify_variables_;
}
void set_may_modify_variables(const std::vector<bool>& may_modify_variables) {
may_modify_variables_ = may_modify_variables;
}
const tf2xla::HostComputeMetadata& host_compute_metadata() const {
return host_compute_metadata_;
}
void set_host_compute_metadata(
const tf2xla::HostComputeMetadata& host_compute_metadata) {
host_compute_metadata_ = host_compute_metadata;
}
const std::vector<XLA_TpuProgram*>& tpu_programs() const {
return tpu_programs_;
}
void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs) {
tpu_programs_.resize(tpu_programs.size());
for (size_t i = 0; i < tpu_programs.size(); ++i) {
tpu_programs_[i] = tpu_programs[i];
}
}
const TPUExecutableInfoProto& executable_info() const {
return executable_info_;
}
void set_executable_info(const TPUExecutableInfoProto& executable_info) {
executable_info_ = executable_info;
}
const TPUHostTransferInfoProto& host_transfer_info() const {
return host_transfer_info_;
}
void set_host_transfer_info(
const TPUHostTransferInfoProto& host_transfer_info) {
host_transfer_info_ = host_transfer_info;
}
void set_hlo_metadata(const xla::HloProto& hlo_metadata);
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
private:
void RefreshHloMetadatasPtrs();
std::vector<bool> may_modify_variables_;
tf2xla::HostComputeMetadata host_compute_metadata_;
std::vector<XLA_TpuProgram*> tpu_programs_; // Not owned.
TPUExecutableInfoProto executable_info_;
TPUHostTransferInfoProto host_transfer_info_;
// To be consistent with the TpuProgramGroupInterface::hlo_metadatas()
  // signature, we store HloProto values in hlo_metadatas_ when
  // set_hlo_metadata() is called.
  std::vector<xla::HloProto> hlo_metadatas_;
  std::vector<const xla::HloProto*> hlo_metadatas_ptrs_;
};
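The setters above form the write path driven by CreateTpuProgramGroup in tpu_program_group.cc; the getters are the read path for cache consumers. A read-side sketch (InspectGroup is a hypothetical helper and assumes an already-populated group):

void InspectGroup(const TpuProgramGroup& group) {
  const TPUExecutableInfoProto& exec = group.executable_info();
  const TPUHostTransferInfoProto& host = group.host_transfer_info();
  absl::Span<const xla::HloProto* const> hlos = group.hlo_metadatas();
  VLOG(1) << "programs=" << group.program_count()
          << " hlo_protos=" << hlos.size();
  (void)exec;  // metadata would be consumed by the TPU execute path
  (void)host;
}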

View File

@ -20,8 +20,6 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/core/status.h"
@ -36,16 +34,13 @@ class TpuProgramGroupInterface {
public:
virtual ~TpuProgramGroupInterface() = default;
// Computes program count.
virtual size_t program_count() const = 0;
// Computes total program size.
virtual int64_t program_size() const = 0;
// Unloads and destroys safely Tpu programs.
virtual void UnloadAndDestroyPrograms() = 0;
// Logs program memory summary.

View File

@ -64,8 +64,6 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHostTransferInfo);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
return tensorflow::Status::OK();
}
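Each TFTPU_SET_FN invocation binds one C-API symbol from the dynamically loaded TPU library into the corresponding field of the function-pointer struct. Conceptually it expands to something like the following (a simplification for illustration, not the actual macro body):

// Rough conceptual expansion of
//   TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata);
void* sym = nullptr;
TF_RETURN_IF_ERROR(tensorflow::Env::Default()->GetSymbolFromLibrary(
    library_handle, "TpuProgram_GetHloMetadata", &sym));
tpu_program_fn->TpuProgram_GetHloMetadataFn =
    reinterpret_cast<decltype(tpu_program_fn->TpuProgram_GetHloMetadataFn)>(sym);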