Refactor TpuCompilationCacheEntry interface to return TpuProgramGroupInterface and core_index and makes CacheEntry less transparent and move application specific logics outside of cache.
PiperOrigin-RevId: 324705343 Change-Id: I9dc421df069dbe7dc9bb57695f06e8b636fbc945
This commit is contained in:
parent
9474df4a12
commit
3cf7683cfe
@ -92,8 +92,6 @@ tf_kernel_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":tpu_compilation_cache_factory",
|
":tpu_compilation_cache_factory",
|
||||||
":tpu_compilation_cache_interface",
|
":tpu_compilation_cache_interface",
|
||||||
":tpu_compilation_cache_local_lookup",
|
|
||||||
":tpu_compilation_cache_lookup",
|
|
||||||
":tpu_mesh_state_interface",
|
":tpu_mesh_state_interface",
|
||||||
":tpu_op_consts",
|
":tpu_op_consts",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
@ -210,14 +208,30 @@ cc_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_compilation_cache_entry",
|
name = "tpu_compilation_cache_entry",
|
||||||
|
srcs = ["tpu_compilation_cache_entry.cc"],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"tpu_compilation_cache_entry.h",
|
"tpu_compilation_cache_entry.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":compiled_subgraph",
|
||||||
|
":tpu_compilation_cache_proto_cc",
|
||||||
":tpu_executable_info_proto_cc",
|
":tpu_executable_info_proto_cc",
|
||||||
":tpu_program_group_interface",
|
":tpu_program_group",
|
||||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/lib/core:refcount",
|
"//tensorflow/core/lib/core:refcount",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tpu_compilation_cache_entry_impl",
|
||||||
|
srcs = [],
|
||||||
|
hdrs = ["tpu_compilation_cache_entry_impl.h"],
|
||||||
|
deps = [
|
||||||
|
":compiled_subgraph",
|
||||||
|
":tpu_compilation_cache_interface",
|
||||||
|
":tpu_executable_info_proto_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -288,8 +302,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
|
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||||
"//tensorflow/core/lib/core:status",
|
"//tensorflow/core/lib/core:status",
|
||||||
"@com_google_absl//absl/time",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -329,7 +341,6 @@ cc_library(
|
|||||||
hdrs = ["tpu_compilation_cache_interface.h"],
|
hdrs = ["tpu_compilation_cache_interface.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":compiled_subgraph",
|
":compiled_subgraph",
|
||||||
":tpu_compilation_cache_entry",
|
|
||||||
":tpu_compilation_cache_key",
|
":tpu_compilation_cache_key",
|
||||||
":tpu_compilation_cache_proto_cc",
|
":tpu_compilation_cache_proto_cc",
|
||||||
":tpu_compilation_metrics_hdrs",
|
":tpu_compilation_metrics_hdrs",
|
||||||
@ -361,6 +372,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":compiled_subgraph",
|
":compiled_subgraph",
|
||||||
":tpu_compilation_cache_entry",
|
":tpu_compilation_cache_entry",
|
||||||
|
":tpu_compilation_cache_entry_impl",
|
||||||
":tpu_compilation_cache_interface",
|
":tpu_compilation_cache_interface",
|
||||||
":tpu_compilation_cache_key",
|
":tpu_compilation_cache_key",
|
||||||
":tpu_compilation_cache_proto_cc",
|
":tpu_compilation_cache_proto_cc",
|
||||||
@ -370,7 +382,6 @@ cc_library(
|
|||||||
":tpu_compile_op_support",
|
":tpu_compile_op_support",
|
||||||
":tpu_mesh_state_interface",
|
":tpu_mesh_state_interface",
|
||||||
":tpu_op_consts",
|
":tpu_op_consts",
|
||||||
":tpu_program_c_api_hdrs",
|
|
||||||
":tpu_program_group",
|
":tpu_program_group",
|
||||||
":tpu_util",
|
":tpu_util",
|
||||||
":trace_util_hdrs",
|
":trace_util_hdrs",
|
||||||
@ -380,10 +391,10 @@ cc_library(
|
|||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||||
"@com_google_absl//absl/container:node_hash_map",
|
"@com_google_absl//absl/container:node_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
@ -604,7 +615,6 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":tpu_compilation_cache_entry",
|
":tpu_compilation_cache_entry",
|
||||||
":tpu_compilation_cache_external",
|
":tpu_compilation_cache_external",
|
||||||
":tpu_compilation_cache_interface",
|
|
||||||
":tpu_compilation_cache_local_lookup",
|
":tpu_compilation_cache_local_lookup",
|
||||||
":tpu_compilation_cache_lookup",
|
":tpu_compilation_cache_lookup",
|
||||||
":tpu_executable_info_proto_cc",
|
":tpu_executable_info_proto_cc",
|
||||||
|
|||||||
54
tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc
Normal file
54
tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tpu {
|
||||||
|
|
||||||
|
TpuCompilationCacheEntry::TpuCompilationCacheEntry(
|
||||||
|
const TpuProgramGroupInterface* tpu_program_group, int core_index)
|
||||||
|
: tpu_program_group_(
|
||||||
|
tensorflow::down_cast<const TpuProgramGroup*>(tpu_program_group)),
|
||||||
|
core_index_(core_index) {}
|
||||||
|
|
||||||
|
// Constructor for an empty entry.
|
||||||
|
TpuCompilationCacheEntry::TpuCompilationCacheEntry()
|
||||||
|
: tpu_program_group_(nullptr) {}
|
||||||
|
|
||||||
|
const TPUExecutableInfoProto* TpuCompilationCacheEntry::get_executable_info()
|
||||||
|
const {
|
||||||
|
return &(tpu_program_group_->executable_info());
|
||||||
|
}
|
||||||
|
|
||||||
|
const TPUHostTransferInfoProto*
|
||||||
|
TpuCompilationCacheEntry::get_host_transfer_info() const {
|
||||||
|
return &(tpu_program_group_->host_transfer_info());
|
||||||
|
}
|
||||||
|
|
||||||
|
const xla::HloProto* TpuCompilationCacheEntry::get_hlo_metadata() const {
|
||||||
|
return tpu_program_group_->hlo_metadatas()[core_index_];
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(henrytan,jiawenhao): When should we expect more than one
|
||||||
|
// XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
|
||||||
|
const XLA_TpuProgram* TpuCompilationCacheEntry::get_tpu_program() const {
|
||||||
|
CHECK_EQ(tpu_program_group_->program_count(), 1);
|
||||||
|
return tpu_program_group_->tpu_programs()[core_index_];
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tpu
|
||||||
|
} // namespace tensorflow
|
||||||
@ -18,32 +18,30 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
#include "tensorflow/core/lib/core/refcount.h"
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
|
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
|
||||||
// Cache entry to hold a `TpuProgramGroupInterface` object that can be used to
|
// A version of `CompilationCacheEntry` to access Tpu binary program
|
||||||
// fetch a TPU program for a given TPU core index.
|
// `XLA_TpuProgram`.
|
||||||
class TpuCompilationCacheEntry {
|
class TpuCompilationCacheEntry {
|
||||||
public:
|
public:
|
||||||
explicit TpuCompilationCacheEntry(
|
explicit TpuCompilationCacheEntry(
|
||||||
const TpuProgramGroupInterface* tpu_program_group, int core_index)
|
const TpuProgramGroupInterface* tpu_program_group, int core_index);
|
||||||
: tpu_program_group_(tpu_program_group), core_index_(core_index) {}
|
|
||||||
|
|
||||||
// Constructor for an empty entry.
|
// Constructor for an empty entry.
|
||||||
TpuCompilationCacheEntry() : tpu_program_group_(nullptr), core_index_(-1) {}
|
TpuCompilationCacheEntry();
|
||||||
|
const TPUExecutableInfoProto* get_executable_info() const;
|
||||||
const TpuProgramGroupInterface* tpu_program_group() const {
|
const TPUHostTransferInfoProto* get_host_transfer_info() const;
|
||||||
return tpu_program_group_;
|
const xla::HloProto* get_hlo_metadata() const;
|
||||||
}
|
// TODO(henrytan): maybe nicer to return C++ wrapper of `XLA_TpuProgram`
|
||||||
|
const XLA_TpuProgram* get_tpu_program() const;
|
||||||
int core_index() const { return core_index_; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const TpuProgramGroupInterface* tpu_program_group_;
|
const TpuProgramGroup* tpu_program_group_;
|
||||||
int core_index_;
|
int core_index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,94 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
||||||
|
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
||||||
|
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tpu {
|
||||||
|
// Wrapper for a cache entry that holds a reference to the entry until the
|
||||||
|
// wrapper is deleted. This wrapper is the concrete type of
|
||||||
|
// CompilationCacheEntryRef returned by Lookup.
|
||||||
|
template <typename CacheEntryType>
|
||||||
|
class CompilationCacheEntryRefImpl
|
||||||
|
: public CompilationCacheEntryRef<CacheEntryType> {
|
||||||
|
public:
|
||||||
|
CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent,
|
||||||
|
CompiledSubgraph* entry, int index);
|
||||||
|
~CompilationCacheEntryRefImpl() override;
|
||||||
|
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
TpuCompilationCacheInterface* parent_; // Not owned.
|
||||||
|
// A reference to entry_ is acquired in the constructor and released via
|
||||||
|
// parent->DiscardEntryRefs in the destructor.
|
||||||
|
CompiledSubgraph* entry_;
|
||||||
|
// The index of the program in entry_ that is returned by the get method.
|
||||||
|
int index_;
|
||||||
|
};
|
||||||
|
template <typename CacheEntryType>
|
||||||
|
CompilationCacheEntryRefImpl<CacheEntryType>::CompilationCacheEntryRefImpl(
|
||||||
|
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
||||||
|
: parent_(parent), entry_(entry), index_(index) {
|
||||||
|
if (entry_ == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (entry_->main_entry == nullptr) {
|
||||||
|
entry_->Ref();
|
||||||
|
} else {
|
||||||
|
// This is a sharding/unsharding entry nested in a main entry. Only
|
||||||
|
// refcount the main entry.
|
||||||
|
entry_->main_entry->Ref();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename CacheEntryType>
|
||||||
|
CompilationCacheEntryRefImpl<CacheEntryType>::~CompilationCacheEntryRefImpl() {
|
||||||
|
if (entry_ == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (entry_->main_entry == nullptr) {
|
||||||
|
parent_->DiscardEntryRefs({entry_});
|
||||||
|
} else {
|
||||||
|
parent_->DiscardEntryRefs({entry_->main_entry});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename CacheEntryType>
|
||||||
|
Status CompilationCacheEntryRefImpl<CacheEntryType>::ToSubEntryRef(
|
||||||
|
CompilationCacheFetchTarget fetch_target) {
|
||||||
|
CompiledSubgraph* target = nullptr;
|
||||||
|
switch (fetch_target) {
|
||||||
|
case CompilationCacheFetchTarget::MAIN:
|
||||||
|
target = entry_;
|
||||||
|
break;
|
||||||
|
case CompilationCacheFetchTarget::SHARDING:
|
||||||
|
target = entry_->sharding_entry.get();
|
||||||
|
break;
|
||||||
|
case CompilationCacheFetchTarget::UNSHARDING:
|
||||||
|
target = entry_->unsharding_entry.get();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
|
||||||
|
}
|
||||||
|
if (target == nullptr) {
|
||||||
|
// Cache entry does not have an unsharding subentry. Unref and replace
|
||||||
|
// with nullptr.
|
||||||
|
parent_->DiscardEntryRefs({entry_});
|
||||||
|
}
|
||||||
|
// Otherwise, since the refcount is always on the main entry, we don't
|
||||||
|
// need ref/unref.
|
||||||
|
entry_ = target;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace tpu
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_
|
||||||
@ -16,18 +16,15 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/platform/random.h"
|
#include "tensorflow/core/platform/random.h"
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_util.h"
|
#include "tensorflow/core/tpu/kernels/tpu_util.h"
|
||||||
#include "tensorflow/core/tpu/kernels/trace_util.h"
|
#include "tensorflow/core/tpu/kernels/trace_util.h"
|
||||||
|
|
||||||
@ -51,22 +48,23 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry,
|
|||||||
entry->tpu_program_group =
|
entry->tpu_program_group =
|
||||||
absl::make_unique<TpuProgramGroup>(std::move(tpu_program_group));
|
absl::make_unique<TpuProgramGroup>(std::move(tpu_program_group));
|
||||||
entry->initialized = true;
|
entry->initialized = true;
|
||||||
|
|
||||||
if (entry->initialization_status.ok()) {
|
|
||||||
// Compute the entries total size once all members are initialized.
|
|
||||||
entry->total_size = entry->ComputeTotalSize();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<CompiledSubgraph> CreateAndInitializeCompiledSubgraph(
|
|
||||||
CompiledSubgraph* main_entry) {
|
|
||||||
auto entry = absl::make_unique<CompiledSubgraph>();
|
|
||||||
entry->main_entry = main_entry;
|
|
||||||
entry->tpu_program_group = absl::make_unique<TpuProgramGroup>();
|
|
||||||
return entry;
|
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl(
|
||||||
|
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
||||||
|
: CompilationCacheEntryRefImpl<TpuCompilationCacheEntry>(parent, entry,
|
||||||
|
index) {}
|
||||||
|
|
||||||
|
TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() {
|
||||||
|
if (entry_ == nullptr) {
|
||||||
|
// Create an empty entry if the entry is nullptr. This corresponds to
|
||||||
|
// non-existing sharding/unsharding entries.
|
||||||
|
return TpuCompilationCacheEntry();
|
||||||
|
}
|
||||||
|
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
|
||||||
|
}
|
||||||
|
|
||||||
CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
||||||
const string& key,
|
const string& key,
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>& initialize_program,
|
const std::function<Status(TpuProgramGroupInterface*)>& initialize_program,
|
||||||
@ -75,6 +73,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
|||||||
main_entry->parent = this;
|
main_entry->parent = this;
|
||||||
main_entry->subgraph_key = key;
|
main_entry->subgraph_key = key;
|
||||||
main_entry->uid = get_uid();
|
main_entry->uid = get_uid();
|
||||||
|
// TODO(henrytan): implement TpuCompilationCacheKey.debug_string.
|
||||||
main_entry->cache_entry_debug_string = subgraph_key.prefix;
|
main_entry->cache_entry_debug_string = subgraph_key.prefix;
|
||||||
VLOG(1) << "Cache Initializing Entry Session Debug "
|
VLOG(1) << "Cache Initializing Entry Session Debug "
|
||||||
<< main_entry->cache_entry_debug_string;
|
<< main_entry->cache_entry_debug_string;
|
||||||
@ -113,29 +112,17 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry(
|
|||||||
std::pair<int64, CompiledSubgraph*>(main_entry->uid, main_entry));
|
std::pair<int64, CompiledSubgraph*>(main_entry->uid, main_entry));
|
||||||
CHECK(uid_inserted.second);
|
CHECK(uid_inserted.second);
|
||||||
|
|
||||||
if (tpu_program_group.has_sharding_program()) {
|
if (initialization_status.ok()) {
|
||||||
main_entry->sharding_entry =
|
// Compute the entries total size once all members are initialized.
|
||||||
CreateAndInitializeCompiledSubgraph(main_entry);
|
main_entry->total_size = tpu_program_group.program_size();
|
||||||
TpuProgramGroup sharding_programs;
|
|
||||||
sharding_programs.Initialize(
|
|
||||||
tpu_program_group.tpu_programs(TpuProgramShardingType::kSharding));
|
|
||||||
PopulateEntry(key, main_entry->sharding_entry.get(),
|
|
||||||
std::move(sharding_programs));
|
|
||||||
|
|
||||||
main_entry->unsharding_entry =
|
|
||||||
CreateAndInitializeCompiledSubgraph(main_entry);
|
|
||||||
TpuProgramGroup unsharding_programs;
|
|
||||||
unsharding_programs.Initialize(
|
|
||||||
tpu_program_group.tpu_programs(TpuProgramShardingType::kUnsharding));
|
|
||||||
PopulateEntry(key, main_entry->unsharding_entry.get(),
|
|
||||||
std::move(unsharding_programs));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(henrytan): handle sharding/unsharding.
|
||||||
PopulateEntry(key, main_entry, std::move(tpu_program_group));
|
PopulateEntry(key, main_entry, std::move(tpu_program_group));
|
||||||
|
|
||||||
for (int64 i = 0; i < main_entry->proto_key.size(); ++i) {
|
for (int64 i = 0; i < main_entry->proto_key.size(); ++i) {
|
||||||
auto entry_inserted = entries_by_proto_key_.insert(
|
auto entry_inserted = entries_by_proto_key_.insert(
|
||||||
std::pair<std::string, std::pair<CompiledSubgraph*, int>>(
|
std::pair<string, std::pair<CompiledSubgraph*, int>>(
|
||||||
main_entry->proto_key[i], std::make_pair(main_entry, i)));
|
main_entry->proto_key[i], std::make_pair(main_entry, i)));
|
||||||
CHECK(entry_inserted.second);
|
CHECK(entry_inserted.second);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||||
@ -45,6 +46,17 @@ namespace tpu {
|
|||||||
|
|
||||||
class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
class TpuCompilationCacheExternal : public TpuCompilationCacheInterface {
|
||||||
public:
|
public:
|
||||||
|
using Status = ::stream_executor::port::Status;
|
||||||
|
|
||||||
|
class EntryRefImpl
|
||||||
|
: public CompilationCacheEntryRefImpl<TpuCompilationCacheEntry> {
|
||||||
|
public:
|
||||||
|
EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry,
|
||||||
|
int index);
|
||||||
|
|
||||||
|
TpuCompilationCacheEntry get() override;
|
||||||
|
};
|
||||||
|
|
||||||
explicit TpuCompilationCacheExternal(int64 max_cache_size)
|
explicit TpuCompilationCacheExternal(int64 max_cache_size)
|
||||||
: TpuCompilationCacheInterface(max_cache_size) {}
|
: TpuCompilationCacheInterface(max_cache_size) {}
|
||||||
|
|
||||||
|
|||||||
@ -38,77 +38,10 @@ void TpuCompilationCacheInterface::RefHolder::AddRef(CompiledSubgraph* entry) {
|
|||||||
entries_.push_back(entry);
|
entries_.push_back(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string TpuCompilationCacheInterface::RefHolder::DebugString() const {
|
string TpuCompilationCacheInterface::RefHolder::DebugString() const {
|
||||||
return "TpuCompilationCacheRefHolder";
|
return "TpuCompilationCacheRefHolder";
|
||||||
}
|
}
|
||||||
|
|
||||||
CompilationCacheEntryRef::CompilationCacheEntryRef()
|
|
||||||
: parent_(nullptr), entry_(nullptr), index_(0) {}
|
|
||||||
|
|
||||||
CompilationCacheEntryRef::CompilationCacheEntryRef(
|
|
||||||
TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
|
|
||||||
: parent_(parent), entry_(entry), index_(index) {
|
|
||||||
if (entry_ == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (entry_->main_entry == nullptr) {
|
|
||||||
entry_->Ref();
|
|
||||||
} else {
|
|
||||||
// This is a sharding/unsharding entry nested in a main entry. Only
|
|
||||||
// refcount the main entry.
|
|
||||||
entry_->main_entry->Ref();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
CompilationCacheEntryRef::~CompilationCacheEntryRef() {
|
|
||||||
if (entry_ == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (entry_->main_entry == nullptr) {
|
|
||||||
parent_->DiscardEntryRefs({entry_});
|
|
||||||
} else {
|
|
||||||
parent_->DiscardEntryRefs({entry_->main_entry});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TpuCompilationCacheEntry CompilationCacheEntryRef::get() {
|
|
||||||
if (entry_ == nullptr) {
|
|
||||||
// Create an empty entry if the entry is nullptr. This corresponds to
|
|
||||||
// non-existing sharding/unsharding entries.
|
|
||||||
return TpuCompilationCacheEntry();
|
|
||||||
}
|
|
||||||
|
|
||||||
return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status CompilationCacheEntryRef::ToSubEntryRef(
|
|
||||||
CompilationCacheFetchTarget fetch_target) {
|
|
||||||
CompiledSubgraph* target = nullptr;
|
|
||||||
switch (fetch_target) {
|
|
||||||
case CompilationCacheFetchTarget::MAIN:
|
|
||||||
target = entry_;
|
|
||||||
break;
|
|
||||||
case CompilationCacheFetchTarget::SHARDING:
|
|
||||||
target = entry_->sharding_entry.get();
|
|
||||||
break;
|
|
||||||
case CompilationCacheFetchTarget::UNSHARDING:
|
|
||||||
target = entry_->unsharding_entry.get();
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (target == nullptr) {
|
|
||||||
// Cache entry does not have an unsharding subentry. Unref and replace
|
|
||||||
// with nullptr.
|
|
||||||
parent_->DiscardEntryRefs({entry_});
|
|
||||||
}
|
|
||||||
// Otherwise, since the refcount is always on the main entry, we don't
|
|
||||||
// need ref/unref.
|
|
||||||
entry_ = target;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
TpuCompilationCacheInterface::TpuCompilationCacheInterface(int64 max_cache_size)
|
TpuCompilationCacheInterface::TpuCompilationCacheInterface(int64 max_cache_size)
|
||||||
: max_cache_size_(max_cache_size) {
|
: max_cache_size_(max_cache_size) {
|
||||||
CHECK_GE(max_cache_size_, 0);
|
CHECK_GE(max_cache_size_, 0);
|
||||||
@ -223,7 +156,7 @@ void TpuCompilationCacheInterface::UnloadAndDestroy(CompiledSubgraph* entry) {
|
|||||||
entry->Unref();
|
entry->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t TpuCompilationCacheInterface::RemoveEntry(const std::string& key) {
|
size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) {
|
||||||
auto erased = cache_.erase(key);
|
auto erased = cache_.erase(key);
|
||||||
TpuCompilationMetrics::SetCacheEntryCount(cache_.size());
|
TpuCompilationMetrics::SetCacheEntryCount(cache_.size());
|
||||||
|
|
||||||
@ -263,7 +196,7 @@ CompiledSubgraph* TpuCompilationCacheInterface::DiscardEntryRef(
|
|||||||
}
|
}
|
||||||
erased = entries_by_uid_.erase(entry->uid);
|
erased = entries_by_uid_.erase(entry->uid);
|
||||||
CHECK_EQ(erased, 1);
|
CHECK_EQ(erased, 1);
|
||||||
for (const std::string& key : entry->proto_key) {
|
for (const string& key : entry->proto_key) {
|
||||||
erased = entries_by_proto_key_.erase(key);
|
erased = entries_by_proto_key_.erase(key);
|
||||||
CHECK_EQ(erased, 1);
|
CHECK_EQ(erased, 1);
|
||||||
}
|
}
|
||||||
@ -336,10 +269,10 @@ void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TpuCompilationCacheInterface::InsertEntry(const std::string& key,
|
void TpuCompilationCacheInterface::InsertEntry(const string& key,
|
||||||
CompiledSubgraph* entry) {
|
CompiledSubgraph* entry) {
|
||||||
auto cache_inserted =
|
auto cache_inserted =
|
||||||
cache_.insert(std::pair<std::string, CompiledSubgraph*>(key, entry));
|
cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
|
||||||
CHECK(cache_inserted.second);
|
CHECK(cache_inserted.second);
|
||||||
TpuCompilationMetrics::SetCacheEntryCount(cache_.size());
|
TpuCompilationMetrics::SetCacheEntryCount(cache_.size());
|
||||||
|
|
||||||
@ -362,8 +295,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
|
|||||||
const TpuCompilationCacheKey& subgraph_key,
|
const TpuCompilationCacheKey& subgraph_key,
|
||||||
const SessionMetadata* session_metadata,
|
const SessionMetadata* session_metadata,
|
||||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||||
std::vector<std::string>* proto_key,
|
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||||
std::vector<bool>* may_modify_variables,
|
|
||||||
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
|
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
|
||||||
std::vector<CompiledSubgraph*> removed_entries;
|
std::vector<CompiledSubgraph*> removed_entries;
|
||||||
@ -376,7 +308,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string TpuCompilationCacheInterface::FindCacheKey(
|
string TpuCompilationCacheInterface::FindCacheKey(
|
||||||
const TpuCompilationCacheKey& subgraph_key) {
|
const TpuCompilationCacheKey& subgraph_key) {
|
||||||
if (!subgraph_key.has_guaranteed_const) {
|
if (!subgraph_key.has_guaranteed_const) {
|
||||||
return subgraph_key.prefix;
|
return subgraph_key.prefix;
|
||||||
@ -399,8 +331,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
|||||||
const TpuCompilationCacheKey& subgraph_key,
|
const TpuCompilationCacheKey& subgraph_key,
|
||||||
const SessionMetadata* session_metadata,
|
const SessionMetadata* session_metadata,
|
||||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||||
std::vector<std::string>* proto_key,
|
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||||
std::vector<bool>* may_modify_variables,
|
|
||||||
std::vector<CompiledSubgraph*>* removed_entries,
|
std::vector<CompiledSubgraph*>* removed_entries,
|
||||||
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
|
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
|
||||||
@ -414,18 +345,17 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
|||||||
// for the lifetime of the object, see InitializeEntry() call below.
|
// for the lifetime of the object, see InitializeEntry() call below.
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
|
|
||||||
std::string cache_key = FindCacheKey(subgraph_key);
|
string cache_key = FindCacheKey(subgraph_key);
|
||||||
auto iter = cache_.find(cache_key);
|
auto iter = cache_.find(cache_key);
|
||||||
bool is_new_key = iter == cache_.end();
|
bool is_new_key = iter == cache_.end();
|
||||||
|
|
||||||
const std::string session_name =
|
const string session_name = tpu::SessionNameFromMetadata(session_metadata);
|
||||||
tpu::SessionNameFromMetadata(session_metadata);
|
|
||||||
|
|
||||||
if (is_new_key) {
|
if (is_new_key) {
|
||||||
cache_key = subgraph_key.ToString();
|
cache_key = subgraph_key.ToString();
|
||||||
TpuCompilationMetrics::IncrementCacheLookupCount(
|
TpuCompilationMetrics::IncrementCacheLookupCount(
|
||||||
/*is_cache_hit=*/false, session_name);
|
/*is_cache_hit=*/false, session_name);
|
||||||
const std::string msg =
|
const string msg =
|
||||||
strings::StrCat("TPU host compilation cache miss: cache_key(",
|
strings::StrCat("TPU host compilation cache miss: cache_key(",
|
||||||
cache_key, "), session_name(", session_name, ")");
|
cache_key, "), session_name(", session_name, ")");
|
||||||
TRACESTRING(msg);
|
TRACESTRING(msg);
|
||||||
@ -434,7 +364,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
|||||||
// Check if caller has disabled compilation. Set using
|
// Check if caller has disabled compilation. Set using
|
||||||
// internal::ScopedTpuCompileDisabler.
|
// internal::ScopedTpuCompileDisabler.
|
||||||
if (!UtilApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
|
if (!UtilApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
|
||||||
const std::string error_msg = strings::StrCat(
|
const string error_msg = strings::StrCat(
|
||||||
"[TpuCompilationDisabled]: Compilation cache miss, but compilation "
|
"[TpuCompilationDisabled]: Compilation cache miss, but compilation "
|
||||||
"disabled, session_name(",
|
"disabled, session_name(",
|
||||||
session_name, ") Debug String: ", subgraph_key.debug_string);
|
session_name, ") Debug String: ", subgraph_key.debug_string);
|
||||||
@ -473,7 +403,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
|||||||
} else {
|
} else {
|
||||||
TpuCompilationMetrics::IncrementCacheLookupCount(
|
TpuCompilationMetrics::IncrementCacheLookupCount(
|
||||||
/*is_cache_hit=*/true, session_name);
|
/*is_cache_hit=*/true, session_name);
|
||||||
const std::string msg =
|
const string msg =
|
||||||
strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
|
strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
|
||||||
"), session_name(", session_name, ")");
|
"), session_name(", session_name, ")");
|
||||||
TRACESTRING(msg);
|
TRACESTRING(msg);
|
||||||
@ -536,8 +466,8 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
|||||||
return entry->initialization_status;
|
return entry->initialization_status;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuCompilationCacheInterface::GetKeysFromUid(
|
Status TpuCompilationCacheInterface::GetKeysFromUid(int64 uid,
|
||||||
int64 uid, std::vector<std::string>* keys) {
|
std::vector<string>* keys) {
|
||||||
keys->clear();
|
keys->clear();
|
||||||
|
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
@ -549,49 +479,5 @@ Status TpuCompilationCacheInterface::GetKeysFromUid(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuCompilationCacheInterface::Lookup(
|
|
||||||
int64 uid, int proto_index,
|
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
|
||||||
entry->reset();
|
|
||||||
|
|
||||||
profiler::TraceMe proto_lookup_traceme(
|
|
||||||
"TPU compilation cache proto lookup by uid",
|
|
||||||
/*level=*/2);
|
|
||||||
|
|
||||||
absl::MutexLock lock(&mu_);
|
|
||||||
const auto iter = entries_by_uid_.find(uid);
|
|
||||||
if (iter == entries_by_uid_.end()) {
|
|
||||||
return errors::NotFound("No subgraph found for uid ", uid);
|
|
||||||
}
|
|
||||||
CompiledSubgraph* cache_entry = iter->second;
|
|
||||||
if (proto_index < 0 ||
|
|
||||||
proto_index >= cache_entry->tpu_program_group->program_count()) {
|
|
||||||
return errors::NotFound("No proto found for core index ", proto_index,
|
|
||||||
" in subgraph with uid ", uid);
|
|
||||||
}
|
|
||||||
*entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
|
|
||||||
proto_index);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TpuCompilationCacheInterface::Lookup(
|
|
||||||
const std::string& proto_key,
|
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
|
||||||
entry->reset();
|
|
||||||
|
|
||||||
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
|
|
||||||
/*level=*/2);
|
|
||||||
|
|
||||||
absl::MutexLock lock(&mu_);
|
|
||||||
const auto iter = entries_by_proto_key_.find(proto_key);
|
|
||||||
if (iter == entries_by_proto_key_.end()) {
|
|
||||||
return errors::NotFound("No proto found for key ", proto_key);
|
|
||||||
}
|
|
||||||
CompiledSubgraph* cache_entry = iter->second.first;
|
|
||||||
int proto_index = iter->second.second;
|
|
||||||
*entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
|
|
||||||
proto_index);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -32,7 +32,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/protobuf/config.pb.h"
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
|
||||||
#include "tensorflow/core/tpu/kernels/trace_util.h"
|
#include "tensorflow/core/tpu/kernels/trace_util.h"
|
||||||
@ -49,20 +48,18 @@ class CompilationRefHolder : public ResourceBase {
|
|||||||
~CompilationRefHolder() override = default;
|
~CompilationRefHolder() override = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wrapper for a cache entry returned by all the TpuCompilationCacheInterface
|
// Base class for a reference to a cached tpu program. A unique_ptr to a
|
||||||
// `Lookup` methods, and ensures the underlying proto is not garbage-collected
|
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
|
||||||
// until the client discards the ptr.
|
// and ensures the underlying proto is not garbage-collected until the client
|
||||||
|
// discards the ptr.
|
||||||
|
template <typename CacheEntryType>
|
||||||
class CompilationCacheEntryRef {
|
class CompilationCacheEntryRef {
|
||||||
public:
|
public:
|
||||||
CompilationCacheEntryRef();
|
virtual ~CompilationCacheEntryRef() = default;
|
||||||
CompilationCacheEntryRef(TpuCompilationCacheInterface* parent,
|
|
||||||
CompiledSubgraph* entry, int index);
|
|
||||||
|
|
||||||
virtual ~CompilationCacheEntryRef();
|
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
|
||||||
|
// of the tpu::CompilationCacheEntryRef.
|
||||||
// Returns a TpuCompilationCacheEntry that should not be used beyond the
|
virtual CacheEntryType get() = 0;
|
||||||
// lifetime of the CompilationCacheEntryRef.
|
|
||||||
virtual TpuCompilationCacheEntry get();
|
|
||||||
|
|
||||||
// Mutates this ref to point to the entry's subentry (for
|
// Mutates this ref to point to the entry's subentry (for
|
||||||
// sharding/unsharding) or main entry (unchanged) as specified by
|
// sharding/unsharding) or main entry (unchanged) as specified by
|
||||||
@ -72,15 +69,7 @@ class CompilationCacheEntryRef {
|
|||||||
//
|
//
|
||||||
// If the requested subentry does not exist, the ref will point to a nullptr
|
// If the requested subentry does not exist, the ref will point to a nullptr
|
||||||
// entry, and the original entry will be unref'ed.
|
// entry, and the original entry will be unref'ed.
|
||||||
virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
|
virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0;
|
||||||
|
|
||||||
protected:
|
|
||||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
|
||||||
// A reference to entry_ is acquired in the constructor and released via
|
|
||||||
// parent->DiscardEntryRefs in the destructor.
|
|
||||||
CompiledSubgraph* entry_;
|
|
||||||
// The index of the program in entry_ that is returned by the get method.
|
|
||||||
int index_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TpuCompilationCacheInterface : public ResourceBase {
|
class TpuCompilationCacheInterface : public ResourceBase {
|
||||||
@ -108,8 +97,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
const TpuCompilationCacheKey& subgraph_key,
|
const TpuCompilationCacheKey& subgraph_key,
|
||||||
const SessionMetadata* session_metadata,
|
const SessionMetadata* session_metadata,
|
||||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||||
std::vector<std::string>* proto_key,
|
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||||
std::vector<bool>* may_modify_variables,
|
|
||||||
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
||||||
|
|
||||||
@ -136,18 +124,19 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
// Looks up an executable corresponding to the model-parallel core index of
|
// Looks up an executable corresponding to the model-parallel core index of
|
||||||
// the subgraph represented by key. On success a pointer to an EntryRef
|
// the subgraph represented by key. On success a pointer to an EntryRef
|
||||||
// holding the program is returned in entry.
|
// holding the program is returned in entry.
|
||||||
Status Lookup(const std::string& proto_key,
|
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry);
|
Status Lookup(const string& proto_key, std::unique_ptr<CacheEntryRef>* entry);
|
||||||
|
|
||||||
// Looks up an executable corresponding to the model-parallel core index of
|
// Looks up an executable corresponding to the model-parallel core index of
|
||||||
// the subgraph represented by uid. On success a pointer to an EntryRef
|
// the subgraph represented by uid. On success a pointer to an EntryRef
|
||||||
// holding the program is returned in entry.
|
// holding the program is returned in entry.
|
||||||
|
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||||
Status Lookup(int64 uid, int proto_index,
|
Status Lookup(int64 uid, int proto_index,
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry);
|
std::unique_ptr<CacheEntryRef>* entry);
|
||||||
|
|
||||||
// Looks up the subgraph represented by uid, and returns the vector of keys,
|
// Looks up the subgraph represented by uid, and returns the vector of keys,
|
||||||
// one per core, corresponding to that subgraph.
|
// one per core, corresponding to that subgraph.
|
||||||
Status GetKeysFromUid(int64 uid, std::vector<std::string>* keys);
|
Status GetKeysFromUid(int64 uid, std::vector<string>* keys);
|
||||||
|
|
||||||
// Makes a reference holder for this cache, that can be stored in the per-step
|
// Makes a reference holder for this cache, that can be stored in the per-step
|
||||||
// resource manager and will ensure that compiled entries persist until the
|
// resource manager and will ensure that compiled entries persist until the
|
||||||
@ -181,7 +170,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
// parent_->DiscardEntryRefs.
|
// parent_->DiscardEntryRefs.
|
||||||
void AddRef(CompiledSubgraph* entry);
|
void AddRef(CompiledSubgraph* entry);
|
||||||
|
|
||||||
std::string DebugString() const override;
|
string DebugString() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TpuCompilationCacheInterface* parent_; // Not owned.
|
TpuCompilationCacheInterface* parent_; // Not owned.
|
||||||
@ -196,8 +185,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
const TpuCompilationCacheKey& subgraph_key,
|
const TpuCompilationCacheKey& subgraph_key,
|
||||||
const SessionMetadata* session_metadata,
|
const SessionMetadata* session_metadata,
|
||||||
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
CompilationRefHolder* per_step_ref_holder, int64* uid,
|
||||||
std::vector<std::string>* proto_key,
|
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
|
||||||
std::vector<bool>* may_modify_variables,
|
|
||||||
std::vector<CompiledSubgraph*>* removed_entries,
|
std::vector<CompiledSubgraph*>* removed_entries,
|
||||||
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
absl::Span<const xla::HloProto* const>* hlo_metadatas,
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
|
||||||
@ -242,14 +230,14 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
// Removes the entry with given key from cache.
|
// Removes the entry with given key from cache.
|
||||||
size_t RemoveEntry(const std::string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
// Inserts the given key and entry to cache.
|
// Inserts the given key and entry to cache.
|
||||||
void InsertEntry(const std::string& key, CompiledSubgraph* entry)
|
void InsertEntry(const string& key, CompiledSubgraph* entry)
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
// Returns the cache key matching given subgraph_key.
|
// Returns the cache key matching given subgraph_key.
|
||||||
std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
|
string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
// Creates a new entry by running initialize_programs and places it in the
|
// Creates a new entry by running initialize_programs and places it in the
|
||||||
@ -259,7 +247,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
//
|
//
|
||||||
// **InitializeEntry releases mu_ during the call to initialize_programs.**
|
// **InitializeEntry releases mu_ during the call to initialize_programs.**
|
||||||
virtual CompiledSubgraph* InitializeEntry(
|
virtual CompiledSubgraph* InitializeEntry(
|
||||||
const std::string& key,
|
const string& key,
|
||||||
const std::function<Status(TpuProgramGroupInterface*)>&
|
const std::function<Status(TpuProgramGroupInterface*)>&
|
||||||
initialize_programs,
|
initialize_programs,
|
||||||
const TpuCompilationCacheKey& subgraph_key)
|
const TpuCompilationCacheKey& subgraph_key)
|
||||||
@ -288,16 +276,13 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
// cache_ key matching a given subgraph key. When doing a lookup, check
|
// cache_ key matching a given subgraph key. When doing a lookup, check
|
||||||
// session_key_map_ first to avoid unnecessay fingerprint computation.
|
// session_key_map_ first to avoid unnecessay fingerprint computation.
|
||||||
// Map from key prefix + session_handle to a cache_ key.
|
// Map from key prefix + session_handle to a cache_ key.
|
||||||
absl::node_hash_map<std::string, std::string> session_key_map_
|
absl::node_hash_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
|
||||||
ABSL_GUARDED_BY(mu_);
|
|
||||||
// Map from key prefix + fingerprint to a cache_ key.
|
// Map from key prefix + fingerprint to a cache_ key.
|
||||||
absl::node_hash_map<std::string, std::string> fingerprint_key_map_
|
absl::node_hash_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
|
||||||
ABSL_GUARDED_BY(mu_);
|
|
||||||
// All the subgraph entries that can be looked up in the cache. An entry is
|
// All the subgraph entries that can be looked up in the cache. An entry is
|
||||||
// marked for eviction iff it is present in cache_ and not in
|
// marked for eviction iff it is present in cache_ and not in
|
||||||
// entries_by_last_use_.
|
// entries_by_last_use_.
|
||||||
std::unordered_map<std::string, CompiledSubgraph*> cache_
|
std::unordered_map<string, CompiledSubgraph*> cache_ ABSL_GUARDED_BY(mu_);
|
||||||
ABSL_GUARDED_BY(mu_);
|
|
||||||
// All the subgraph entries that can be looked up in the cache, indexed by
|
// All the subgraph entries that can be looked up in the cache, indexed by
|
||||||
// uid.
|
// uid.
|
||||||
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
|
absl::node_hash_map<int64, CompiledSubgraph*> entries_by_uid_
|
||||||
@ -305,7 +290,7 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
// All the protos that can be looked up in the cache, indexed by proto
|
// All the protos that can be looked up in the cache, indexed by proto
|
||||||
// key. The value of the map is a subgraph and the index of the proto compiled
|
// key. The value of the map is a subgraph and the index of the proto compiled
|
||||||
// for that subgraph.
|
// for that subgraph.
|
||||||
std::unordered_map<std::string, std::pair<CompiledSubgraph*, int>>
|
std::unordered_map<string, std::pair<CompiledSubgraph*, int>>
|
||||||
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
|
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
|
||||||
// Map from last_use to entry, used to mark entries for eviction in LRU
|
// Map from last_use to entry, used to mark entries for eviction in LRU
|
||||||
// order. If an entry's last_use counter is not present as a key in
|
// order. If an entry's last_use counter is not present as a key in
|
||||||
@ -319,6 +304,50 @@ class TpuCompilationCacheInterface : public ResourceBase {
|
|||||||
TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
|
TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
|
||||||
delete;
|
delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||||
|
Status TpuCompilationCacheInterface::Lookup(
|
||||||
|
int64 uid, int proto_index, std::unique_ptr<CacheEntryRef>* entry) {
|
||||||
|
entry->reset();
|
||||||
|
|
||||||
|
profiler::TraceMe proto_lookup_traceme(
|
||||||
|
"TPU compilation cache proto lookup by uid",
|
||||||
|
/*level=*/2);
|
||||||
|
|
||||||
|
absl::MutexLock lock(&mu_);
|
||||||
|
const auto iter = entries_by_uid_.find(uid);
|
||||||
|
if (iter == entries_by_uid_.end()) {
|
||||||
|
return errors::NotFound("No subgraph found for uid ", uid);
|
||||||
|
}
|
||||||
|
CompiledSubgraph* cache_entry = iter->second;
|
||||||
|
if (proto_index < 0 ||
|
||||||
|
proto_index >= cache_entry->tpu_program_group->program_count()) {
|
||||||
|
return errors::NotFound("No proto found for core index ", proto_index,
|
||||||
|
" in subgraph with uid ", uid);
|
||||||
|
}
|
||||||
|
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename CacheEntryRef, typename CacheEntryRefImpl>
|
||||||
|
Status TpuCompilationCacheInterface::Lookup(
|
||||||
|
const string& proto_key, std::unique_ptr<CacheEntryRef>* entry) {
|
||||||
|
entry->reset();
|
||||||
|
|
||||||
|
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
|
||||||
|
/*level=*/2);
|
||||||
|
|
||||||
|
absl::MutexLock lock(&mu_);
|
||||||
|
const auto iter = entries_by_proto_key_.find(proto_key);
|
||||||
|
if (iter == entries_by_proto_key_.end()) {
|
||||||
|
return errors::NotFound("No proto found for key ", proto_key);
|
||||||
|
}
|
||||||
|
CompiledSubgraph* cache_entry = iter->second.first;
|
||||||
|
int proto_index = iter->second.second;
|
||||||
|
*entry = absl::make_unique<CacheEntryRefImpl>(this, cache_entry, proto_index);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|||||||
@ -16,50 +16,70 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
namespace {
|
||||||
|
class CompilationCacheFetchTargetUtility {
|
||||||
|
public:
|
||||||
|
CompilationCacheFetchTargetUtility()
|
||||||
|
: names_({"Invalid", "Main", "Sharding", "Unsharding"}) {}
|
||||||
|
|
||||||
|
std::string name(CompilationCacheFetchTarget target) const {
|
||||||
|
return names_[static_cast<int>(target)];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::vector<std::string> names_;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string GetName(CompilationCacheFetchTarget target) {
|
||||||
|
static const auto* util = new CompilationCacheFetchTargetUtility();
|
||||||
|
return util->name(target);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
|
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
|
||||||
TpuCompilationCacheInterface* cache)
|
TpuCompilationCacheInterface* cache)
|
||||||
: cache_(cache) {
|
: cache_(cache) {}
|
||||||
cache_->Ref();
|
|
||||||
}
|
|
||||||
|
|
||||||
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
|
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
|
||||||
cache_->Unref();
|
cache_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuCompilationCacheLocalLookup::Lookup(
|
Status TpuCompilationCacheLocalLookup::Lookup(
|
||||||
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
|
const string& proto_key,
|
||||||
|
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) {
|
CompilationCacheFetchTarget fetch_target) {
|
||||||
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
|
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
|
||||||
/*level=*/2);
|
/*level=*/2);
|
||||||
Status s = cache_->Lookup(proto_key, entry);
|
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
|
||||||
|
proto_key, entry);
|
||||||
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
|
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
|
||||||
<< s;
|
<< s;
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
s = (*entry)->ToSubEntryRef(fetch_target);
|
s = (*entry)->ToSubEntryRef(fetch_target);
|
||||||
VLOG(1) << "Fetched subentry: "
|
|
||||||
<< CompilationCacheFetchTarget_Name(fetch_target) << " with status "
|
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
||||||
<< s;
|
<< s;
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuCompilationCacheLocalLookup::Lookup(
|
Status TpuCompilationCacheLocalLookup::Lookup(
|
||||||
int64 uid, int proto_index,
|
int64 uid, int proto_index,
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) {
|
CompilationCacheFetchTarget fetch_target) {
|
||||||
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
|
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
|
||||||
/*level=*/2);
|
/*level=*/2);
|
||||||
Status s = cache_->Lookup(uid, proto_index, entry);
|
Status s = cache_->Lookup<TpuCompilationCacheEntryRef, EntryRefImpl>(
|
||||||
|
uid, proto_index, entry);
|
||||||
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
|
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
|
||||||
<< " in local subgraph cache status " << s;
|
<< " in local subgraph cache status " << s;
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
s = (*entry)->ToSubEntryRef(fetch_target);
|
s = (*entry)->ToSubEntryRef(fetch_target);
|
||||||
VLOG(1) << "Fetched subentry: "
|
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
|
||||||
<< CompilationCacheFetchTarget_Name(fetch_target) << " with status "
|
|
||||||
<< s;
|
<< s;
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
@ -67,5 +87,6 @@ Status TpuCompilationCacheLocalLookup::Lookup(
|
|||||||
string TpuCompilationCacheLocalLookup::DebugString() const {
|
string TpuCompilationCacheLocalLookup::DebugString() const {
|
||||||
return "TpuCompilationCacheLocalLookup";
|
return "TpuCompilationCacheLocalLookup";
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -28,17 +28,24 @@ namespace tpu {
|
|||||||
// Class for looking up TPU programs when the execute and compile Op are in the
|
// Class for looking up TPU programs when the execute and compile Op are in the
|
||||||
// same address space. The proto is simply looked up in the compilation cache,
|
// same address space. The proto is simply looked up in the compilation cache,
|
||||||
// without any serialization taking place.
|
// without any serialization taking place.
|
||||||
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
|
class TpuCompilationCacheLocalLookup
|
||||||
|
: public TpuCompilationCacheLookup<
|
||||||
|
CompilationCacheEntryRef<TpuCompilationCacheEntry>> {
|
||||||
public:
|
public:
|
||||||
|
using TpuCompilationCacheEntryRef =
|
||||||
|
::tensorflow::tpu::CompilationCacheEntryRef<TpuCompilationCacheEntry>;
|
||||||
|
using EntryRefImpl =
|
||||||
|
::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl;
|
||||||
|
|
||||||
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
|
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
|
||||||
~TpuCompilationCacheLocalLookup() override;
|
~TpuCompilationCacheLocalLookup() override;
|
||||||
|
|
||||||
Status Lookup(const string& proto_key,
|
Status Lookup(const string& proto_key,
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) override;
|
CompilationCacheFetchTarget fetch_target) override;
|
||||||
|
|
||||||
Status Lookup(int64 uid, int proto_index,
|
Status Lookup(int64 uid, int proto_index,
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
std::unique_ptr<TpuCompilationCacheEntryRef>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) override;
|
CompilationCacheFetchTarget fetch_target) override;
|
||||||
|
|
||||||
string DebugString() const override;
|
string DebugString() const override;
|
||||||
|
|||||||
@ -23,11 +23,10 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
|
||||||
// TODO(b/162241759): consider merging TpuCompilationCacheLookup and
|
|
||||||
// TpuCompilationCacheInterface.
|
|
||||||
// Base class allowing Execute Ops to look up TPU programs. Different subclasses
|
// Base class allowing Execute Ops to look up TPU programs. Different subclasses
|
||||||
// are used when the execute Op is in the same address space as the compile Op,
|
// are used when the execute Op is in the same address space as the compile Op,
|
||||||
// and when they need to communicate over RPC.
|
// and when they need to communicate over RPC.
|
||||||
|
template <typename TpuCompilationCacheEntryRefType>
|
||||||
class TpuCompilationCacheLookup : public ResourceBase {
|
class TpuCompilationCacheLookup : public ResourceBase {
|
||||||
public:
|
public:
|
||||||
~TpuCompilationCacheLookup() override = default;
|
~TpuCompilationCacheLookup() override = default;
|
||||||
@ -44,11 +43,12 @@ class TpuCompilationCacheLookup : public ResourceBase {
|
|||||||
// fetch_target requests one of them, then after this call
|
// fetch_target requests one of them, then after this call
|
||||||
// (*entry)->get().get_executable() will return nullptr.
|
// (*entry)->get().get_executable() will return nullptr.
|
||||||
virtual Status Lookup(const string& proto_key,
|
virtual Status Lookup(const string& proto_key,
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) = 0;
|
CompilationCacheFetchTarget fetch_target) = 0;
|
||||||
|
|
||||||
virtual Status Lookup(const string& proto_key,
|
virtual Status Lookup(
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
const string& proto_key,
|
||||||
|
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry) {
|
||||||
return Lookup(proto_key, std::move(entry),
|
return Lookup(proto_key, std::move(entry),
|
||||||
CompilationCacheFetchTarget::MAIN);
|
CompilationCacheFetchTarget::MAIN);
|
||||||
}
|
}
|
||||||
@ -58,15 +58,17 @@ class TpuCompilationCacheLookup : public ResourceBase {
|
|||||||
// returned in program. The wrapper is guaranteed to be valid only during the
|
// returned in program. The wrapper is guaranteed to be valid only during the
|
||||||
// execution of the Op requesting the proto.
|
// execution of the Op requesting the proto.
|
||||||
virtual Status Lookup(int64 uid, int proto_index,
|
virtual Status Lookup(int64 uid, int proto_index,
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry,
|
||||||
CompilationCacheFetchTarget fetch_target) = 0;
|
CompilationCacheFetchTarget fetch_target) = 0;
|
||||||
|
|
||||||
virtual Status Lookup(int64 uid, int proto_index,
|
virtual Status Lookup(
|
||||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
int64 uid, int proto_index,
|
||||||
|
std::unique_ptr<TpuCompilationCacheEntryRefType>* entry) {
|
||||||
return Lookup(uid, proto_index, std::move(entry),
|
return Lookup(uid, proto_index, std::move(entry),
|
||||||
CompilationCacheFetchTarget::MAIN);
|
CompilationCacheFetchTarget::MAIN);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|||||||
@ -413,6 +413,46 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
Status TpuCompileOpKernelCommon::ComputeArgumentShapes(
|
||||||
|
const tpu::TPUCompileMetadataProto& metadata,
|
||||||
|
const std::vector<TensorShape>& dynamic_shapes,
|
||||||
|
std::vector<TensorShape>* arg_shapes) {
|
||||||
|
arg_shapes->resize(metadata.args_size());
|
||||||
|
int dynamic_shape_pos = 0;
|
||||||
|
for (int i = 0; i < metadata.args_size(); ++i) {
|
||||||
|
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
|
||||||
|
// The XLA compiler determines the shape of each constant by inspecting the
|
||||||
|
// value of its corresponding host-memory tensor. As a result, we don't need
|
||||||
|
// to give the compiler graph-inferred shapes for constant arguments.
|
||||||
|
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
|
||||||
|
PartialTensorShape static_shape(arg.shape());
|
||||||
|
|
||||||
|
TensorShape& shape = (*arg_shapes)[i];
|
||||||
|
if (static_shape.IsFullyDefined()) {
|
||||||
|
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
|
||||||
|
} else {
|
||||||
|
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
|
||||||
|
<< "Too few dynamic shapes";
|
||||||
|
shape = dynamic_shapes[dynamic_shape_pos++];
|
||||||
|
if (!static_shape.IsCompatibleWith(shape)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Mismatch between static and dynamic shape for argument. Static "
|
||||||
|
"shape: ",
|
||||||
|
static_shape.DebugString(),
|
||||||
|
"; dynamic shape: ", shape.DebugString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Checks we consumed all of the dynamic shapes.
|
||||||
|
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
|
||||||
|
<< "Too many dynamic shapes";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Function arguments and return values lose their device assignments, so we
|
// Function arguments and return values lose their device assignments, so we
|
||||||
// must recreate them.
|
// must recreate them.
|
||||||
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
|
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
|
||||||
|
|||||||
@ -99,6 +99,15 @@ class TpuCompileOpKernelCommon {
|
|||||||
const std::vector<TensorShape>& arg_shapes,
|
const std::vector<TensorShape>& arg_shapes,
|
||||||
TpuProgramGroupInterface* tpu_program_group) = 0;
|
TpuProgramGroupInterface* tpu_program_group) = 0;
|
||||||
|
|
||||||
|
// Computes shapes for each argument. Uses both the static shape from the
|
||||||
|
// metadata, and the dynamic shapes where the static shape is not
|
||||||
|
// defined. There must be one dynamic_shape for each argument with a
|
||||||
|
// partially defined shape, in index order.
|
||||||
|
static Status ComputeArgumentShapes(
|
||||||
|
const tpu::TPUCompileMetadataProto& metadata,
|
||||||
|
const std::vector<TensorShape>& dynamic_shapes,
|
||||||
|
std::vector<TensorShape>* arg_shapes);
|
||||||
|
|
||||||
// Performs shape inference on `computation`, filling shape_info with operator
|
// Performs shape inference on `computation`, filling shape_info with operator
|
||||||
// shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
|
// shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
|
||||||
static Status RunShapeInferenceOnComputation(
|
static Status RunShapeInferenceOnComputation(
|
||||||
|
|||||||
@ -540,43 +540,5 @@ Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
|||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
|
|
||||||
const std::vector<TensorShape>& dynamic_shapes,
|
|
||||||
std::vector<TensorShape>* arg_shapes) {
|
|
||||||
arg_shapes->resize(metadata.args_size());
|
|
||||||
int dynamic_shape_pos = 0;
|
|
||||||
for (int i = 0; i < metadata.args_size(); ++i) {
|
|
||||||
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
|
|
||||||
// The XLA compiler determines the shape of each constant by inspecting the
|
|
||||||
// value of its corresponding host-memory tensor. As a result, we don't need
|
|
||||||
// to give the compiler graph-inferred shapes for constant arguments.
|
|
||||||
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
|
|
||||||
PartialTensorShape static_shape(arg.shape());
|
|
||||||
|
|
||||||
TensorShape& shape = (*arg_shapes)[i];
|
|
||||||
if (static_shape.IsFullyDefined()) {
|
|
||||||
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
|
|
||||||
} else {
|
|
||||||
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
|
|
||||||
<< "Too few dynamic shapes";
|
|
||||||
shape = dynamic_shapes[dynamic_shape_pos++];
|
|
||||||
if (!static_shape.IsCompatibleWith(shape)) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Mismatch between static and dynamic shape for argument. Static "
|
|
||||||
"shape: ",
|
|
||||||
static_shape.DebugString(),
|
|
||||||
"; dynamic shape: ", shape.DebugString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Checks we consumed all of the dynamic shapes.
|
|
||||||
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
|
|
||||||
<< "Too many dynamic shapes";
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -159,14 +159,6 @@ se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
|||||||
TPUCompileMetadataProto* metadata,
|
TPUCompileMetadataProto* metadata,
|
||||||
NameAttrList* function_name,
|
NameAttrList* function_name,
|
||||||
std::string* mlir_module);
|
std::string* mlir_module);
|
||||||
|
|
||||||
// Computes shapes for each argument. Uses both the static shape from the
|
|
||||||
// metadata, and the dynamic shapes where the static shape is not
|
|
||||||
// defined. There must be one dynamic_shape for each argument with a
|
|
||||||
// partially defined shape, in index order.
|
|
||||||
Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
|
|
||||||
const std::vector<TensorShape>& dynamic_shapes,
|
|
||||||
std::vector<TensorShape>* arg_shapes);
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|||||||
@ -25,8 +25,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/refcount.h"
|
#include "tensorflow/core/platform/refcount.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
|
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
|
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
|
||||||
#include "tensorflow/core/tpu/tpu_api.h"
|
#include "tensorflow/core/tpu/tpu_api.h"
|
||||||
@ -255,10 +253,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
|||||||
mesh_state_interface));
|
mesh_state_interface));
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
|
|
||||||
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
|
|
||||||
rmgr, tpu::kCompiledProtoCacheResourceName));
|
|
||||||
|
|
||||||
if (enable_whole_mesh_compilations_) {
|
if (enable_whole_mesh_compilations_) {
|
||||||
// If this is a whole mesh compilation mode, create the compilation cache,
|
// If this is a whole mesh compilation mode, create the compilation cache,
|
||||||
// if missing.
|
// if missing.
|
||||||
@ -282,13 +276,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
|||||||
|
|
||||||
if (local_compilation_cache != nullptr) {
|
if (local_compilation_cache != nullptr) {
|
||||||
local_compilation_cache->Unref();
|
local_compilation_cache->Unref();
|
||||||
|
|
||||||
tpu::TpuCompilationCacheLookup* proto_lookup;
|
|
||||||
proto_lookup =
|
|
||||||
new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
ctx, rmgr->Create(rmgr->default_container(),
|
|
||||||
tpu::kCompiledProtoCacheResourceName, proto_lookup));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor* ctx_output;
|
Tensor* ctx_output;
|
||||||
|
|||||||
@ -40,12 +40,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/casts.h"
|
|
||||||
#include "tensorflow/core/platform/tracing.h"
|
#include "tensorflow/core/platform/tracing.h"
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||||
@ -58,10 +56,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
|
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
using ::tensorflow::tpu::CompilationCacheEntryRef;
|
|
||||||
using ::tensorflow::tpu::TpuCompilationCacheLookup;
|
|
||||||
using ::tensorflow::tpu::TpuNodeContext;
|
using ::tensorflow::tpu::TpuNodeContext;
|
||||||
|
using CompilationCacheEntryRef = ::tensorflow::tpu::CompilationCacheEntryRef<
|
||||||
|
::tensorflow::tpu::TpuCompilationCacheEntry>;
|
||||||
|
using TpuCompilationCacheLookup =
|
||||||
|
::tensorflow::tpu::TpuCompilationCacheLookup<CompilationCacheEntryRef>;
|
||||||
|
|
||||||
// Looks up the input `key` in the compilation cache, populating
|
// Looks up the input `key` in the compilation cache, populating
|
||||||
// `*rendezvous_key_base` and `*entry`.
|
// `*rendezvous_key_base` and `*entry`.
|
||||||
@ -639,35 +641,28 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
|
|||||||
profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);
|
profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);
|
||||||
|
|
||||||
string rendezvous_key_base;
|
string rendezvous_key_base;
|
||||||
std::unique_ptr<CompilationCacheEntryRef> entry_ref;
|
std::unique_ptr<CompilationCacheEntryRef> entry;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));
|
GetComputationCacheEntry(context, &rendezvous_key_base, &entry));
|
||||||
|
|
||||||
// Shapes of the inputs and outputs, in xla::Shape form.
|
// Shapes of the inputs and outputs, in xla::Shape form.
|
||||||
tpu::TpuCompilationCacheEntry entry = entry_ref->get();
|
const TPUExecutableInfoProto* proto = entry->get().get_executable_info();
|
||||||
const tpu::TpuProgramGroup* tpu_program_group =
|
|
||||||
tensorflow::down_cast<const tpu::TpuProgramGroup*>(
|
|
||||||
entry.tpu_program_group());
|
|
||||||
CHECK_NE(tpu_program_group, nullptr);
|
|
||||||
const int core_index = entry.core_index();
|
|
||||||
const TPUExecutableInfoProto& executable =
|
|
||||||
tpu_program_group->executable_info(core_index);
|
|
||||||
|
|
||||||
xla::Backend* const backend = node_context->backend();
|
xla::Backend* const backend = node_context->backend();
|
||||||
xla::TransferManager* const transfer_manager = backend->transfer_manager();
|
xla::TransferManager* const transfer_manager = backend->transfer_manager();
|
||||||
TF_RET_CHECK(context->op_device_context());
|
TF_RET_CHECK(context->op_device_context());
|
||||||
se::Stream* stream = context->op_device_context()->stream();
|
se::Stream* stream = context->op_device_context()->stream();
|
||||||
|
|
||||||
TF_RET_CHECK(executable.input_shapes_size() == 1);
|
TF_RET_CHECK(proto->input_shapes_size() == 1);
|
||||||
|
|
||||||
xla::Shape host_shape(executable.input_shapes(0));
|
xla::Shape host_shape(proto->input_shapes(0));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto variable_update_map,
|
auto variable_update_map,
|
||||||
BuildVariableUpdateMap(executable.variable_indices(),
|
BuildVariableUpdateMap(proto->variable_indices(),
|
||||||
fused_device_var_reads_in_computation_inputs_,
|
fused_device_var_reads_in_computation_inputs_,
|
||||||
fused_device_var_updates_in_computation_outputs_,
|
fused_device_var_updates_in_computation_outputs_,
|
||||||
executable.output_tensor_shapes().size()));
|
proto->output_tensor_shapes().size()));
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<InputBuffers> input_buffers,
|
std::unique_ptr<InputBuffers> input_buffers,
|
||||||
BuildComputationInputs(context, host_shape, variable_update_map, backend,
|
BuildComputationInputs(context, host_shape, variable_update_map, backend,
|
||||||
@ -702,9 +697,8 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
|
|||||||
|
|
||||||
// Snapshot the inputs, if a snapshot was requested.
|
// Snapshot the inputs, if a snapshot was requested.
|
||||||
std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
|
std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
|
||||||
if (executable.has_session_module()) {
|
if (proto->has_session_module()) {
|
||||||
hlo_snapshot =
|
hlo_snapshot = std::make_shared<xla::HloSnapshot>(proto->session_module());
|
||||||
std::make_shared<xla::HloSnapshot>(executable.session_module());
|
|
||||||
auto literal =
|
auto literal =
|
||||||
std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
|
std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
|
||||||
transfer_manager->TransferLiteralFromDevice(
|
transfer_manager->TransferLiteralFromDevice(
|
||||||
@ -729,9 +723,9 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
|
|||||||
const uint32 rng_seed = GetXLARandomSeed();
|
const uint32 rng_seed = GetXLARandomSeed();
|
||||||
|
|
||||||
std::unique_ptr<xla::DeviceAssignment> device_assignment;
|
std::unique_ptr<xla::DeviceAssignment> device_assignment;
|
||||||
if (executable.has_device_assignment()) {
|
if (proto->has_device_assignment()) {
|
||||||
TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
|
TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
|
||||||
executable.device_assignment()));
|
proto->device_assignment()));
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(4) << "Input buffers after alias resolution: "
|
VLOG(4) << "Input buffers after alias resolution: "
|
||||||
@ -749,24 +743,24 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
|
|||||||
// we free a memory and reassign it to other users while a program is running,
|
// we free a memory and reassign it to other users while a program is running,
|
||||||
// all subsequent writes to the program that could possibly clobber the memory
|
// all subsequent writes to the program that could possibly clobber the memory
|
||||||
// will depend on the program to finish.
|
// will depend on the program to finish.
|
||||||
const TPUHostTransferInfoProto& host_transfer_info =
|
const TPUHostTransferInfoProto* host_transfer_info =
|
||||||
tpu_program_group->host_transfer_info(core_index);
|
entry->get().get_host_transfer_info();
|
||||||
|
const xla::HloProto* hlo_metadata = entry->get().get_hlo_metadata();
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
xla::ExecutionOutput output,
|
xla::ExecutionOutput output,
|
||||||
TPUExecute(executable, host_transfer_info,
|
TPUExecute(*proto, *host_transfer_info, *hlo_metadata, std::move(input),
|
||||||
*tpu_program_group->hlo_metadata(core_index), std::move(input),
|
|
||||||
rendezvous_key_base, rng_seed, node_context.get(),
|
rendezvous_key_base, rng_seed, node_context.get(),
|
||||||
device_assignment.get(), context->cancellation_manager(),
|
device_assignment.get(), context->cancellation_manager(),
|
||||||
context, stream, transfer_stream_ptr.get(),
|
context, stream, transfer_stream_ptr.get(),
|
||||||
tpu_program_group->tpu_program(core_index)));
|
entry->get().get_tpu_program()));
|
||||||
stream->ThenRecordEvent(definition_event.get());
|
stream->ThenRecordEvent(definition_event.get());
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<OutputBuffers> output_buffers,
|
std::unique_ptr<OutputBuffers> output_buffers,
|
||||||
AllocateOutputTensors(
|
AllocateOutputTensors(context, output.ConsumeResult(),
|
||||||
context, output.ConsumeResult(), executable.output_tensor_shapes(),
|
proto->output_tensor_shapes(), variable_update_map,
|
||||||
variable_update_map, node_context.get(), stream, device_ordinal,
|
node_context.get(), stream, device_ordinal,
|
||||||
input_buffers.get(), definition_event));
|
input_buffers.get(), definition_event));
|
||||||
|
|
||||||
// Transfer the outputs and save the snapshot to disk.
|
// Transfer the outputs and save the snapshot to disk.
|
||||||
if (hlo_snapshot) {
|
if (hlo_snapshot) {
|
||||||
|
|||||||
@ -21,9 +21,6 @@ limitations under the License.
|
|||||||
|
|
||||||
typedef struct XLA_TpuProgram XLA_TpuProgram;
|
typedef struct XLA_TpuProgram XLA_TpuProgram;
|
||||||
|
|
||||||
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
|
|
||||||
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
|
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
// Creates a new TPU program.
|
// Creates a new TPU program.
|
||||||
@ -67,15 +64,6 @@ TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
|
|||||||
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
|
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
|
||||||
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
|
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
|
||||||
|
|
||||||
// Check if TPU program has sharding.
|
|
||||||
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
|
|
||||||
const XLA_TpuProgram* tpu_program);
|
|
||||||
|
|
||||||
// Gets TPU program by sharding type. Return value is valid only when the
|
|
||||||
// `status.status()` returns `OK`.
|
|
||||||
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
|
|
||||||
XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
|
|
||||||
|
|
||||||
struct TfTpu_TpuProgramApiFn {
|
struct TfTpu_TpuProgramApiFn {
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
|
||||||
@ -88,8 +76,6 @@ struct TfTpu_TpuProgramApiFn {
|
|||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
|
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|||||||
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
|
||||||
#include "tensorflow/core/tpu/tpu_api.h"
|
#include "tensorflow/core/tpu/tpu_api.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
@ -99,71 +98,55 @@ StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
|
|||||||
compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
|
compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
|
||||||
per_core_variable_indices, device_assignment);
|
per_core_variable_indices, device_assignment);
|
||||||
}
|
}
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void TpuProgramGroup::Initialize(
|
Status CreateTpuProgramGroup(
|
||||||
absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
|
absl::Span<XLA_TpuProgram* const> xla_tpu_programs,
|
||||||
|
TpuProgramGroupInterface* tpu_program_group_interface) {
|
||||||
CHECK_GT(xla_tpu_programs.size(), 0);
|
CHECK_GT(xla_tpu_programs.size(), 0);
|
||||||
set_tpu_programs(xla_tpu_programs);
|
TpuProgramGroup* tpu_program_group =
|
||||||
|
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
|
||||||
|
CHECK_NE(tpu_program_group, nullptr);
|
||||||
|
tpu_program_group->set_tpu_programs(xla_tpu_programs);
|
||||||
|
|
||||||
std::vector<bool> may_modify_variables_array(xla_tpu_programs.size(), false);
|
// TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
|
||||||
std::vector<TPUExecutableInfoProto> executable_infos(xla_tpu_programs.size());
|
bool may_modify_variables;
|
||||||
std::vector<TPUHostTransferInfoProto> host_transfer_infos(
|
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_programs[0],
|
||||||
xla_tpu_programs.size());
|
&may_modify_variables);
|
||||||
std::vector<xla::HloProto> hlo_metadatas(xla_tpu_programs.size());
|
tpu_program_group->set_may_modify_variables(
|
||||||
for (size_t i = 0; i < xla_tpu_programs.size(); ++i) {
|
std::vector<bool>(1, may_modify_variables));
|
||||||
const XLA_TpuProgram* xla_tpu_program = xla_tpu_programs[i];
|
|
||||||
bool may_modify_variables;
|
|
||||||
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(
|
|
||||||
xla_tpu_program, &may_modify_variables);
|
|
||||||
may_modify_variables_array[i] = may_modify_variables;
|
|
||||||
|
|
||||||
TpuSerializedProto serialized_executable_info;
|
TpuSerializedProto serialized_executable_info;
|
||||||
TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
|
TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
|
||||||
xla_tpu_program, &serialized_executable_info);
|
xla_tpu_programs[0], &serialized_executable_info);
|
||||||
TPUExecutableInfoProto executable_info =
|
TPUExecutableInfoProto executable_info =
|
||||||
se_tpu::DeserializeProto<TPUExecutableInfoProto>(
|
se_tpu::DeserializeProto<TPUExecutableInfoProto>(
|
||||||
serialized_executable_info);
|
serialized_executable_info);
|
||||||
executable_infos[i] = executable_info;
|
tpu_program_group->set_executable_info(executable_info);
|
||||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
|
StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
|
||||||
|
|
||||||
TPUHostTransferInfoProto host_transfer_info;
|
TPUHostTransferInfoProto host_transfer_info;
|
||||||
TpuSerializedProto serialized_host_transfer_info;
|
TpuSerializedProto serialized_host_transfer_info;
|
||||||
TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
|
TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
|
||||||
xla_tpu_program, &serialized_host_transfer_info);
|
xla_tpu_programs[0], &serialized_host_transfer_info);
|
||||||
if (serialized_host_transfer_info.size > 0) {
|
if (serialized_host_transfer_info.size > 0) {
|
||||||
host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
|
host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
|
||||||
serialized_host_transfer_info);
|
serialized_host_transfer_info);
|
||||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
|
StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
|
||||||
}
|
|
||||||
host_transfer_infos[i] = host_transfer_info;
|
|
||||||
|
|
||||||
TpuSerializedProto serialized_hlo_metadata;
|
|
||||||
TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_program,
|
|
||||||
&serialized_hlo_metadata);
|
|
||||||
xla::HloProto hlo_metadata =
|
|
||||||
se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
|
|
||||||
hlo_metadatas[i] = hlo_metadata;
|
|
||||||
StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
|
|
||||||
}
|
}
|
||||||
|
tpu_program_group->set_host_transfer_info(host_transfer_info);
|
||||||
|
|
||||||
may_modify_variables_ = may_modify_variables_array;
|
TpuSerializedProto serialized_hlo_metadata;
|
||||||
executable_infos_ = executable_infos;
|
TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_programs[0],
|
||||||
host_transfer_infos_ = host_transfer_infos;
|
&serialized_hlo_metadata);
|
||||||
hlo_metadatas_ = hlo_metadatas;
|
xla::HloProto hlo_metadata =
|
||||||
RefreshHloMetadatasPtrs();
|
se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
|
||||||
|
tpu_program_group->set_hlo_metadata(hlo_metadata);
|
||||||
|
StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TpuProgramGroup::has_sharding_program() const {
|
} // namespace
|
||||||
for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
|
|
||||||
if (!TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
|
|
||||||
|
|
||||||
int64_t TpuProgramGroup::program_size() const {
|
int64_t TpuProgramGroup::program_size() const {
|
||||||
int64_t total_size = 0;
|
int64_t total_size = 0;
|
||||||
@ -218,6 +201,12 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
|
|||||||
TF_RET_CHECK(per_core_output_shapes.size() ==
|
TF_RET_CHECK(per_core_output_shapes.size() ==
|
||||||
per_core_variable_indices.size());
|
per_core_variable_indices.size());
|
||||||
|
|
||||||
|
// TODO(henrytan): add an interface to TpuProgramGroupInterface to set
|
||||||
|
// may_modify_variables.
|
||||||
|
TpuProgramGroup* tpu_program_group =
|
||||||
|
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
|
||||||
|
tpu_program_group->may_modify_variables_ = may_modify_variables;
|
||||||
|
|
||||||
// With shardable input/output pairs, XLA could generate separate
|
// With shardable input/output pairs, XLA could generate separate
|
||||||
// sharding/unsharding programs along with the main program. The
|
// sharding/unsharding programs along with the main program. The
|
||||||
// sharding/unsharding programs will be in nested entries of the AOT
|
// sharding/unsharding programs will be in nested entries of the AOT
|
||||||
@ -232,20 +221,17 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
|
|||||||
TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
|
TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
|
||||||
xla_tpu_programs.size() == metadata.num_cores_per_replica());
|
xla_tpu_programs.size() == metadata.num_cores_per_replica());
|
||||||
|
|
||||||
// TODO(henrytan): add an interface to TpuProgramGroupInterface to set
|
TF_RETURN_IF_ERROR(
|
||||||
// may_modify_variables.
|
CreateTpuProgramGroup(xla_tpu_programs, tpu_program_group));
|
||||||
TpuProgramGroup* tpu_program_group =
|
|
||||||
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
|
|
||||||
tpu_program_group->Initialize(xla_tpu_programs);
|
|
||||||
tpu_program_group->may_modify_variables_ = may_modify_variables;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
|
TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
|
||||||
: may_modify_variables_(std::move(other.may_modify_variables_)),
|
: may_modify_variables_(std::move(other.may_modify_variables_)),
|
||||||
|
host_compute_metadata_(std::move(other.host_compute_metadata_)),
|
||||||
tpu_programs_(std::move(other.tpu_programs_)),
|
tpu_programs_(std::move(other.tpu_programs_)),
|
||||||
executable_infos_(std::move(other.executable_infos_)),
|
executable_info_(std::move(other.executable_info_)),
|
||||||
host_transfer_infos_(std::move(other.host_transfer_infos_)),
|
host_transfer_info_(std::move(other.host_transfer_info_)),
|
||||||
hlo_metadatas_(std::move(other.hlo_metadatas_)) {
|
hlo_metadatas_(std::move(other.hlo_metadatas_)) {
|
||||||
RefreshHloMetadatasPtrs();
|
RefreshHloMetadatasPtrs();
|
||||||
}
|
}
|
||||||
@ -262,12 +248,6 @@ absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
|
|||||||
return hlo_metadatas_ptrs_;
|
return hlo_metadatas_ptrs_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const {
|
|
||||||
CHECK_GE(index, 0);
|
|
||||||
CHECK_LT(index, hlo_metadatas_ptrs_.size());
|
|
||||||
return hlo_metadatas_ptrs_[index];
|
|
||||||
}
|
|
||||||
|
|
||||||
void TpuProgramGroup::RefreshHloMetadatasPtrs() {
|
void TpuProgramGroup::RefreshHloMetadatasPtrs() {
|
||||||
hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
|
hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
|
||||||
for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
|
for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
|
||||||
@ -282,47 +262,6 @@ Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<bool>& TpuProgramGroup::may_modify_variables() const {
|
|
||||||
return may_modify_variables_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TpuProgramGroup::set_may_modify_variables(
|
|
||||||
const std::vector<bool>& may_modify_variables) {
|
|
||||||
may_modify_variables_ = may_modify_variables;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
|
|
||||||
return tpu_programs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const {
|
|
||||||
CHECK_GE(index, 0);
|
|
||||||
CHECK_LT(index, tpu_programs_.size());
|
|
||||||
return tpu_programs_[index];
|
|
||||||
}
|
|
||||||
|
|
||||||
void TpuProgramGroup::set_tpu_programs(
|
|
||||||
absl::Span<XLA_TpuProgram* const> tpu_programs) {
|
|
||||||
tpu_programs_.resize(tpu_programs.size());
|
|
||||||
for (size_t i = 0; i < tpu_programs.size(); ++i) {
|
|
||||||
tpu_programs_[i] = tpu_programs[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
|
|
||||||
int index) const {
|
|
||||||
CHECK_GE(index, 0);
|
|
||||||
CHECK_LT(index, executable_infos_.size());
|
|
||||||
return executable_infos_[index];
|
|
||||||
}
|
|
||||||
|
|
||||||
const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
|
|
||||||
int index) const {
|
|
||||||
CHECK_GE(index, 0);
|
|
||||||
CHECK_LT(index, host_transfer_infos_.size());
|
|
||||||
return host_transfer_infos_[index];
|
|
||||||
}
|
|
||||||
|
|
||||||
/*static*/
|
/*static*/
|
||||||
Status TpuProgramGroup::CompileAndBuild(
|
Status TpuProgramGroup::CompileAndBuild(
|
||||||
const TpuCompilationRequestProto& compilation_request,
|
const TpuCompilationRequestProto& compilation_request,
|
||||||
@ -348,27 +287,15 @@ Status TpuProgramGroup::CompileAndBuild(
|
|||||||
TF_RET_CHECK(count == 1 ||
|
TF_RET_CHECK(count == 1 ||
|
||||||
count == compilation_request.metadata().num_cores_per_replica());
|
count == compilation_request.metadata().num_cores_per_replica());
|
||||||
|
|
||||||
VLOG(1) << "Initialize TpuProgramGroup.";
|
VLOG(1) << "CreateTpuProgramGroup";
|
||||||
TpuProgramGroup* tpu_program_group =
|
Status serialize_status =
|
||||||
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
|
CreateTpuProgramGroup(absl::MakeConstSpan(&xla_tpu_programs[0], count),
|
||||||
tpu_program_group->Initialize(
|
tpu_program_group_interface);
|
||||||
absl::MakeConstSpan(&xla_tpu_programs[0], count));
|
VLOG(1) << absl::StrCat("Run CreateTpuProgramGroup completed. StatusCode: ",
|
||||||
|
serialize_status.code());
|
||||||
TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
|
TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
|
||||||
return status.status();
|
return serialize_status;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
|
|
||||||
TpuProgramShardingType sharding_type) const {
|
|
||||||
std::vector<XLA_TpuProgram*> tpu_programs;
|
|
||||||
tpu_programs.reserve(tpu_programs_.size());
|
|
||||||
for (size_t i = 0; i < tpu_programs_.size(); ++i) {
|
|
||||||
if (TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
|
|
||||||
tpu_programs.push_back(TpuProgramApiFn()->TpuProgram_GetTpuProgramFn(
|
|
||||||
tpu_programs_[i], sharding_type));
|
|
||||||
CHECK_NE(tpu_programs[i], nullptr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tpu_programs;
|
|
||||||
}
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -102,16 +102,11 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
|
|||||||
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
|
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
|
||||||
TpuProgramGroupInterface* tpu_program_group_interface);
|
TpuProgramGroupInterface* tpu_program_group_interface);
|
||||||
|
|
||||||
// Initializes `TpuProgramGroup` object with `xla_tpu_programs`.
|
|
||||||
void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);
|
|
||||||
|
|
||||||
TpuProgramGroup() = default;
|
TpuProgramGroup() = default;
|
||||||
TpuProgramGroup(TpuProgramGroup&& other);
|
TpuProgramGroup(TpuProgramGroup&& other);
|
||||||
TpuProgramGroup& operator=(TpuProgramGroup&&) = delete;
|
TpuProgramGroup& operator=(TpuProgramGroup&&) = delete;
|
||||||
|
|
||||||
bool has_sharding_program() const override;
|
size_t program_count() const override { return tpu_programs_.size(); }
|
||||||
|
|
||||||
size_t program_count() const override;
|
|
||||||
|
|
||||||
int64_t program_size() const override;
|
int64_t program_size() const override;
|
||||||
|
|
||||||
@ -122,29 +117,58 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
|
|||||||
Status LogCompilationStats(const TpuCompilationCacheKey& key,
|
Status LogCompilationStats(const TpuCompilationCacheKey& key,
|
||||||
absl::Duration duration) override;
|
absl::Duration duration) override;
|
||||||
|
|
||||||
const std::vector<bool>& may_modify_variables() const override;
|
const std::vector<bool>& may_modify_variables() const override {
|
||||||
void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
|
return may_modify_variables_;
|
||||||
|
}
|
||||||
|
void set_may_modify_variables(const std::vector<bool>& may_modify_variables) {
|
||||||
|
may_modify_variables_ = may_modify_variables;
|
||||||
|
}
|
||||||
|
|
||||||
const std::vector<XLA_TpuProgram*>& tpu_programs() const;
|
const tf2xla::HostComputeMetadata& host_compute_metadata() const {
|
||||||
std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;
|
return host_compute_metadata_;
|
||||||
const XLA_TpuProgram* tpu_program(int index) const;
|
}
|
||||||
void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs);
|
void set_host_compute_metadata(
|
||||||
|
const tf2xla::HostComputeMetadata& host_compute_metadata) {
|
||||||
|
host_compute_metadata_ = host_compute_metadata;
|
||||||
|
}
|
||||||
|
|
||||||
const TPUExecutableInfoProto& executable_info(int index) const;
|
const std::vector<XLA_TpuProgram*>& tpu_programs() const {
|
||||||
|
return tpu_programs_;
|
||||||
|
}
|
||||||
|
void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs) {
|
||||||
|
tpu_programs_.resize(tpu_programs.size());
|
||||||
|
for (size_t i = 0; i < tpu_programs.size(); ++i) {
|
||||||
|
tpu_programs_[i] = tpu_programs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const TPUExecutableInfoProto& executable_info() const {
|
||||||
|
return executable_info_;
|
||||||
|
}
|
||||||
|
void set_executable_info(const TPUExecutableInfoProto& executable_info) {
|
||||||
|
executable_info_ = executable_info;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TPUHostTransferInfoProto& host_transfer_info() const {
|
||||||
|
return host_transfer_info_;
|
||||||
|
}
|
||||||
|
void set_host_transfer_info(
|
||||||
|
const TPUHostTransferInfoProto& host_transfer_info) {
|
||||||
|
host_transfer_info_ = host_transfer_info;
|
||||||
|
}
|
||||||
|
|
||||||
const TPUHostTransferInfoProto& host_transfer_info(int index) const;
|
|
||||||
void set_hlo_metadata(const xla::HloProto& hlo_metadata);
|
void set_hlo_metadata(const xla::HloProto& hlo_metadata);
|
||||||
const xla::HloProto* hlo_metadata(int index) const;
|
|
||||||
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
|
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void RefreshHloMetadatasPtrs();
|
void RefreshHloMetadatasPtrs();
|
||||||
|
|
||||||
std::vector<bool> may_modify_variables_;
|
std::vector<bool> may_modify_variables_;
|
||||||
|
tf2xla::HostComputeMetadata host_compute_metadata_;
|
||||||
|
|
||||||
std::vector<XLA_TpuProgram*> tpu_programs_; // Not owned.
|
std::vector<XLA_TpuProgram*> tpu_programs_; // Not owned.
|
||||||
std::vector<TPUExecutableInfoProto> executable_infos_;
|
TPUExecutableInfoProto executable_info_;
|
||||||
std::vector<TPUHostTransferInfoProto> host_transfer_infos_;
|
TPUHostTransferInfoProto host_transfer_info_;
|
||||||
|
|
||||||
// To be consistent with the TpuProgramGroupInterface::hlo_metadatas()
|
// To be consistent with the TpuProgramGroupInterface::hlo_metadatas()
|
||||||
// signature, we store HloProto values in hlo_metadatas_ when
|
// signature, we store HloProto values in hlo_metadatas_ when
|
||||||
|
|||||||
@ -20,8 +20,6 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/time/time.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -36,16 +34,13 @@ class TpuProgramGroupInterface {
|
|||||||
public:
|
public:
|
||||||
virtual ~TpuProgramGroupInterface() = default;
|
virtual ~TpuProgramGroupInterface() = default;
|
||||||
|
|
||||||
// Check if whether sharding/unsharding program exists.
|
|
||||||
virtual bool has_sharding_program() const = 0;
|
|
||||||
|
|
||||||
// Computes program count.
|
// Computes program count.
|
||||||
virtual size_t program_count() const = 0;
|
virtual size_t program_count() const = 0;
|
||||||
|
|
||||||
// Computes total program size.
|
// Computes total program size.
|
||||||
virtual int64_t program_size() const = 0;
|
virtual int64_t program_size() const = 0;
|
||||||
|
|
||||||
// Unloads and destroys safely TPU programs.
|
// Unloads and destroys safely Tpu programs.
|
||||||
virtual void UnloadAndDestroyPrograms() = 0;
|
virtual void UnloadAndDestroyPrograms() = 0;
|
||||||
|
|
||||||
// Logs program memory summary.
|
// Logs program memory summary.
|
||||||
|
|||||||
@ -64,8 +64,6 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
|
|||||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHostTransferInfo);
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHostTransferInfo);
|
||||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata);
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata);
|
||||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
|
||||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
|
|
||||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
|
|
||||||
|
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user