From 3cf7683cfea4277223b940bfb5563efd541badd4 Mon Sep 17 00:00:00 2001 From: Russell Power Date: Mon, 3 Aug 2020 16:16:10 -0700 Subject: [PATCH] Refactor the `TpuCompilationCacheEntry` interface to return `TpuProgramGroupInterface` and `core_index`, make CacheEntry less transparent, and move application-specific logic outside of the cache. PiperOrigin-RevId: 324705343 Change-Id: I9dc421df069dbe7dc9bb57695f06e8b636fbc945 --- tensorflow/core/tpu/kernels/BUILD | 28 ++- .../kernels/tpu_compilation_cache_entry.cc | 54 +++++ .../tpu/kernels/tpu_compilation_cache_entry.h | 26 ++- .../tpu_compilation_cache_entry_impl.h | 94 +++++++++ .../kernels/tpu_compilation_cache_external.cc | 53 ++--- .../kernels/tpu_compilation_cache_external.h | 12 ++ .../tpu_compilation_cache_interface.cc | 144 ++----------- .../kernels/tpu_compilation_cache_interface.h | 111 ++++++---- .../tpu_compilation_cache_local_lookup.cc | 43 +++- .../tpu_compilation_cache_local_lookup.h | 13 +- .../kernels/tpu_compilation_cache_lookup.h | 18 +- .../core/tpu/kernels/tpu_compile_op_common.cc | 40 ++++ .../core/tpu/kernels/tpu_compile_op_common.h | 9 + .../tpu/kernels/tpu_compile_op_support.cc | 38 ---- .../core/tpu/kernels/tpu_compile_op_support.h | 8 - .../core/tpu/kernels/tpu_configuration_ops.cc | 13 -- tensorflow/core/tpu/kernels/tpu_execute_op.cc | 58 +++--- .../core/tpu/kernels/tpu_program_c_api.h | 14 -- .../core/tpu/kernels/tpu_program_group.cc | 189 ++++++------------ .../core/tpu/kernels/tpu_program_group.h | 58 ++++-- .../tpu/kernels/tpu_program_group_interface.h | 7 +- tensorflow/core/tpu/tpu_library_init_fns.inc | 2 - 22 files changed, 523 insertions(+), 509 deletions(-) create mode 100644 tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc create mode 100644 tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 1336f52ed34..3b7d0e09c08 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ 
b/tensorflow/core/tpu/kernels/BUILD @@ -92,8 +92,6 @@ tf_kernel_library( deps = [ ":tpu_compilation_cache_factory", ":tpu_compilation_cache_interface", - ":tpu_compilation_cache_local_lookup", - ":tpu_compilation_cache_lookup", ":tpu_mesh_state_interface", ":tpu_op_consts", "//tensorflow/c:tf_status", @@ -210,14 +208,30 @@ cc_library( cc_library( name = "tpu_compilation_cache_entry", + srcs = ["tpu_compilation_cache_entry.cc"], hdrs = [ "tpu_compilation_cache_entry.h", ], deps = [ + ":compiled_subgraph", + ":tpu_compilation_cache_proto_cc", ":tpu_executable_info_proto_cc", - ":tpu_program_group_interface", + ":tpu_program_group", "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/core:framework", "//tensorflow/core/lib/core:refcount", + "//tensorflow/core/platform:casts", + ], +) + +cc_library( + name = "tpu_compilation_cache_entry_impl", + srcs = [], + hdrs = ["tpu_compilation_cache_entry_impl.h"], + deps = [ + ":compiled_subgraph", + ":tpu_compilation_cache_interface", + ":tpu_executable_info_proto_cc", ], ) @@ -288,8 +302,6 @@ cc_library( "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc", "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/core/lib/core:status", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", ], ) @@ -329,7 +341,6 @@ cc_library( hdrs = ["tpu_compilation_cache_interface.h"], deps = [ ":compiled_subgraph", - ":tpu_compilation_cache_entry", ":tpu_compilation_cache_key", ":tpu_compilation_cache_proto_cc", ":tpu_compilation_metrics_hdrs", @@ -361,6 +372,7 @@ cc_library( deps = [ ":compiled_subgraph", ":tpu_compilation_cache_entry", + ":tpu_compilation_cache_entry_impl", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_key", ":tpu_compilation_cache_proto_cc", @@ -370,7 +382,6 @@ cc_library( ":tpu_compile_op_support", ":tpu_mesh_state_interface", ":tpu_op_consts", - ":tpu_program_c_api_hdrs", ":tpu_program_group", ":tpu_util", ":trace_util_hdrs", @@ -380,10 +391,10 @@ cc_library( 
"//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", @@ -604,7 +615,6 @@ cc_library( deps = [ ":tpu_compilation_cache_entry", ":tpu_compilation_cache_external", - ":tpu_compilation_cache_interface", ":tpu_compilation_cache_local_lookup", ":tpu_compilation_cache_lookup", ":tpu_executable_info_proto_cc", diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc new file mode 100644 index 00000000000..73f55853306 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" + +#include "tensorflow/core/platform/casts.h" + +namespace tensorflow { +namespace tpu { + +TpuCompilationCacheEntry::TpuCompilationCacheEntry( + const TpuProgramGroupInterface* tpu_program_group, int core_index) + : tpu_program_group_( + tensorflow::down_cast(tpu_program_group)), + core_index_(core_index) {} + +// Constructor for an empty entry. +TpuCompilationCacheEntry::TpuCompilationCacheEntry() + : tpu_program_group_(nullptr) {} + +const TPUExecutableInfoProto* TpuCompilationCacheEntry::get_executable_info() + const { + return &(tpu_program_group_->executable_info()); +} + +const TPUHostTransferInfoProto* +TpuCompilationCacheEntry::get_host_transfer_info() const { + return &(tpu_program_group_->host_transfer_info()); +} + +const xla::HloProto* TpuCompilationCacheEntry::get_hlo_metadata() const { + return tpu_program_group_->hlo_metadatas()[core_index_]; +} + +// TODO(henrytan,jiawenhao): When should we expect more than one +// XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then. +const XLA_TpuProgram* TpuCompilationCacheEntry::get_tpu_program() const { + CHECK_EQ(tpu_program_group_->program_count(), 1); + return tpu_program_group_->tpu_programs()[core_index_]; +} + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h index 832d76bfceb..b3766b8b4dd 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h @@ -18,32 +18,30 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group.h" namespace tensorflow { namespace tpu { -// Cache entry to hold a `TpuProgramGroupInterface` object that can be used to -// fetch a TPU program for a given TPU core index. +// A version of `CompilationCacheEntry` to access Tpu binary program +// `XLA_TpuProgram`. class TpuCompilationCacheEntry { public: explicit TpuCompilationCacheEntry( - const TpuProgramGroupInterface* tpu_program_group, int core_index) - : tpu_program_group_(tpu_program_group), core_index_(core_index) {} - + const TpuProgramGroupInterface* tpu_program_group, int core_index); // Constructor for an empty entry. - TpuCompilationCacheEntry() : tpu_program_group_(nullptr), core_index_(-1) {} - - const TpuProgramGroupInterface* tpu_program_group() const { - return tpu_program_group_; - } - - int core_index() const { return core_index_; } + TpuCompilationCacheEntry(); + const TPUExecutableInfoProto* get_executable_info() const; + const TPUHostTransferInfoProto* get_host_transfer_info() const; + const xla::HloProto* get_hlo_metadata() const; + // TODO(henrytan): maybe nicer to return C++ wrapper of `XLA_TpuProgram` + const XLA_TpuProgram* get_tpu_program() const; private: - const TpuProgramGroupInterface* tpu_program_group_; + const TpuProgramGroup* tpu_program_group_; int core_index_; }; + } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h new file mode 100644 index 00000000000..0632d9a163f --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h @@ -0,0 +1,94 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ +#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" +namespace tensorflow { +namespace tpu { +// Wrapper for a cache entry that holds a reference to the entry until the +// wrapper is deleted. This wrapper is the concrete type of +// CompilationCacheEntryRef returned by Lookup. +template +class CompilationCacheEntryRefImpl + : public CompilationCacheEntryRef { + public: + CompilationCacheEntryRefImpl(TpuCompilationCacheInterface* parent, + CompiledSubgraph* entry, int index); + ~CompilationCacheEntryRefImpl() override; + Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override; + + protected: + TpuCompilationCacheInterface* parent_; // Not owned. + // A reference to entry_ is acquired in the constructor and released via + // parent->DiscardEntryRefs in the destructor. + CompiledSubgraph* entry_; + // The index of the program in entry_ that is returned by the get method. 
+ int index_; +}; +template +CompilationCacheEntryRefImpl::CompilationCacheEntryRefImpl( + TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index) + : parent_(parent), entry_(entry), index_(index) { + if (entry_ == nullptr) { + return; + } + if (entry_->main_entry == nullptr) { + entry_->Ref(); + } else { + // This is a sharding/unsharding entry nested in a main entry. Only + // refcount the main entry. + entry_->main_entry->Ref(); + } +} +template +CompilationCacheEntryRefImpl::~CompilationCacheEntryRefImpl() { + if (entry_ == nullptr) { + return; + } + if (entry_->main_entry == nullptr) { + parent_->DiscardEntryRefs({entry_}); + } else { + parent_->DiscardEntryRefs({entry_->main_entry}); + } +} +template +Status CompilationCacheEntryRefImpl::ToSubEntryRef( + CompilationCacheFetchTarget fetch_target) { + CompiledSubgraph* target = nullptr; + switch (fetch_target) { + case CompilationCacheFetchTarget::MAIN: + target = entry_; + break; + case CompilationCacheFetchTarget::SHARDING: + target = entry_->sharding_entry.get(); + break; + case CompilationCacheFetchTarget::UNSHARDING: + target = entry_->unsharding_entry.get(); + break; + default: + return xla::InvalidArgument("Invalid fetch target: %d", fetch_target); + } + if (target == nullptr) { + // Cache entry does not have an unsharding subentry. Unref and replace + // with nullptr. + parent_->DiscardEntryRefs({entry_}); + } + // Otherwise, since the refcount is always on the main entry, we don't + // need ref/unref. 
+ entry_ = target; + return Status::OK(); +} +} // namespace tpu +} // namespace tensorflow +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_IMPL_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc index 80010d70cd4..b4b18d1743b 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc @@ -16,18 +16,15 @@ limitations under the License. #include -#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/profiler/lib/traceme.h" -#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h" #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/kernels/trace_util.h" @@ -51,22 +48,23 @@ void PopulateEntry(const std::string& key, CompiledSubgraph* entry, entry->tpu_program_group = absl::make_unique(std::move(tpu_program_group)); entry->initialized = true; - - if (entry->initialization_status.ok()) { - // Compute the entries total size once all members are initialized. 
- entry->total_size = entry->ComputeTotalSize(); - } -} - -std::unique_ptr CreateAndInitializeCompiledSubgraph( - CompiledSubgraph* main_entry) { - auto entry = absl::make_unique(); - entry->main_entry = main_entry; - entry->tpu_program_group = absl::make_unique(); - return entry; } } // namespace +TpuCompilationCacheExternal::EntryRefImpl::EntryRefImpl( + TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index) + : CompilationCacheEntryRefImpl(parent, entry, + index) {} + +TpuCompilationCacheEntry TpuCompilationCacheExternal::EntryRefImpl::get() { + if (entry_ == nullptr) { + // Create an empty entry if the entry is nullptr. This corresponds to + // non-existing sharding/unsharding entries. + return TpuCompilationCacheEntry(); + } + return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_); +} + CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry( const string& key, const std::function& initialize_program, @@ -75,6 +73,7 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry( main_entry->parent = this; main_entry->subgraph_key = key; main_entry->uid = get_uid(); + // TODO(henrytan): implement TpuCompilationCacheKey.debug_string. 
main_entry->cache_entry_debug_string = subgraph_key.prefix; VLOG(1) << "Cache Initializing Entry Session Debug " << main_entry->cache_entry_debug_string; @@ -113,29 +112,17 @@ CompiledSubgraph* TpuCompilationCacheExternal::InitializeEntry( std::pair(main_entry->uid, main_entry)); CHECK(uid_inserted.second); - if (tpu_program_group.has_sharding_program()) { - main_entry->sharding_entry = - CreateAndInitializeCompiledSubgraph(main_entry); - TpuProgramGroup sharding_programs; - sharding_programs.Initialize( - tpu_program_group.tpu_programs(TpuProgramShardingType::kSharding)); - PopulateEntry(key, main_entry->sharding_entry.get(), - std::move(sharding_programs)); - - main_entry->unsharding_entry = - CreateAndInitializeCompiledSubgraph(main_entry); - TpuProgramGroup unsharding_programs; - unsharding_programs.Initialize( - tpu_program_group.tpu_programs(TpuProgramShardingType::kUnsharding)); - PopulateEntry(key, main_entry->unsharding_entry.get(), - std::move(unsharding_programs)); + if (initialization_status.ok()) { + // Compute the entries total size once all members are initialized. + main_entry->total_size = tpu_program_group.program_size(); } + // TODO(henrytan): handle sharding/unsharding. PopulateEntry(key, main_entry, std::move(tpu_program_group)); for (int64 i = 0; i < main_entry->proto_key.size(); ++i) { auto entry_inserted = entries_by_proto_key_.insert( - std::pair>( + std::pair>( main_entry->proto_key[i], std::make_pair(main_entry, i))); CHECK(entry_inserted.second); } diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h index 51b5ffbed0d..86615b15d4c 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_impl.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" @@ -45,6 +46,17 @@ namespace tpu { class TpuCompilationCacheExternal : public TpuCompilationCacheInterface { public: + using Status = ::stream_executor::port::Status; + + class EntryRefImpl + : public CompilationCacheEntryRefImpl { + public: + EntryRefImpl(TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, + int index); + + TpuCompilationCacheEntry get() override; + }; + explicit TpuCompilationCacheExternal(int64 max_cache_size) : TpuCompilationCacheInterface(max_cache_size) {} diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc index 4cd2b864203..9e1aedf92ce 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc @@ -38,77 +38,10 @@ void TpuCompilationCacheInterface::RefHolder::AddRef(CompiledSubgraph* entry) { entries_.push_back(entry); } -std::string TpuCompilationCacheInterface::RefHolder::DebugString() const { +string TpuCompilationCacheInterface::RefHolder::DebugString() const { return "TpuCompilationCacheRefHolder"; } -CompilationCacheEntryRef::CompilationCacheEntryRef() - : parent_(nullptr), entry_(nullptr), index_(0) {} - -CompilationCacheEntryRef::CompilationCacheEntryRef( - TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index) - : parent_(parent), entry_(entry), index_(index) { - if (entry_ == nullptr) { - return; - } - if (entry_->main_entry == nullptr) { - entry_->Ref(); - } else { - // 
This is a sharding/unsharding entry nested in a main entry. Only - // refcount the main entry. - entry_->main_entry->Ref(); - } -} - -CompilationCacheEntryRef::~CompilationCacheEntryRef() { - if (entry_ == nullptr) { - return; - } - if (entry_->main_entry == nullptr) { - parent_->DiscardEntryRefs({entry_}); - } else { - parent_->DiscardEntryRefs({entry_->main_entry}); - } -} - -TpuCompilationCacheEntry CompilationCacheEntryRef::get() { - if (entry_ == nullptr) { - // Create an empty entry if the entry is nullptr. This corresponds to - // non-existing sharding/unsharding entries. - return TpuCompilationCacheEntry(); - } - - return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_); -} - -Status CompilationCacheEntryRef::ToSubEntryRef( - CompilationCacheFetchTarget fetch_target) { - CompiledSubgraph* target = nullptr; - switch (fetch_target) { - case CompilationCacheFetchTarget::MAIN: - target = entry_; - break; - case CompilationCacheFetchTarget::SHARDING: - target = entry_->sharding_entry.get(); - break; - case CompilationCacheFetchTarget::UNSHARDING: - target = entry_->unsharding_entry.get(); - break; - default: - return xla::InvalidArgument("Invalid fetch target: %d", fetch_target); - } - - if (target == nullptr) { - // Cache entry does not have an unsharding subentry. Unref and replace - // with nullptr. - parent_->DiscardEntryRefs({entry_}); - } - // Otherwise, since the refcount is always on the main entry, we don't - // need ref/unref. 
- entry_ = target; - return Status::OK(); -} - TpuCompilationCacheInterface::TpuCompilationCacheInterface(int64 max_cache_size) : max_cache_size_(max_cache_size) { CHECK_GE(max_cache_size_, 0); @@ -223,7 +156,7 @@ void TpuCompilationCacheInterface::UnloadAndDestroy(CompiledSubgraph* entry) { entry->Unref(); } -size_t TpuCompilationCacheInterface::RemoveEntry(const std::string& key) { +size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) { auto erased = cache_.erase(key); TpuCompilationMetrics::SetCacheEntryCount(cache_.size()); @@ -263,7 +196,7 @@ CompiledSubgraph* TpuCompilationCacheInterface::DiscardEntryRef( } erased = entries_by_uid_.erase(entry->uid); CHECK_EQ(erased, 1); - for (const std::string& key : entry->proto_key) { + for (const string& key : entry->proto_key) { erased = entries_by_proto_key_.erase(key); CHECK_EQ(erased, 1); } @@ -336,10 +269,10 @@ void TpuCompilationCacheInterface::LookupEntryMarkedForEviction( } } -void TpuCompilationCacheInterface::InsertEntry(const std::string& key, +void TpuCompilationCacheInterface::InsertEntry(const string& key, CompiledSubgraph* entry) { auto cache_inserted = - cache_.insert(std::pair(key, entry)); + cache_.insert(std::pair(key, entry)); CHECK(cache_inserted.second); TpuCompilationMetrics::SetCacheEntryCount(cache_.size()); @@ -362,8 +295,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent( const TpuCompilationCacheKey& subgraph_key, const SessionMetadata* session_metadata, CompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector* proto_key, - std::vector* may_modify_variables, + std::vector* proto_key, std::vector* may_modify_variables, absl::Span* hlo_metadatas, const std::function& compile_function) { std::vector removed_entries; @@ -376,7 +308,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsent( return status; } -std::string TpuCompilationCacheInterface::FindCacheKey( +string TpuCompilationCacheInterface::FindCacheKey( const TpuCompilationCacheKey& subgraph_key) { 
if (!subgraph_key.has_guaranteed_const) { return subgraph_key.prefix; @@ -399,8 +331,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( const TpuCompilationCacheKey& subgraph_key, const SessionMetadata* session_metadata, CompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector* proto_key, - std::vector* may_modify_variables, + std::vector* proto_key, std::vector* may_modify_variables, std::vector* removed_entries, absl::Span* hlo_metadatas, const std::function& compile_function) { @@ -414,18 +345,17 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( // for the lifetime of the object, see InitializeEntry() call below. absl::MutexLock lock(&mu_); - std::string cache_key = FindCacheKey(subgraph_key); + string cache_key = FindCacheKey(subgraph_key); auto iter = cache_.find(cache_key); bool is_new_key = iter == cache_.end(); - const std::string session_name = - tpu::SessionNameFromMetadata(session_metadata); + const string session_name = tpu::SessionNameFromMetadata(session_metadata); if (is_new_key) { cache_key = subgraph_key.ToString(); TpuCompilationMetrics::IncrementCacheLookupCount( /*is_cache_hit=*/false, session_name); - const std::string msg = + const string msg = strings::StrCat("TPU host compilation cache miss: cache_key(", cache_key, "), session_name(", session_name, ")"); TRACESTRING(msg); @@ -434,7 +364,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( // Check if caller has disabled compilation. Set using // internal::ScopedTpuCompileDisabler. 
if (!UtilApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) { - const std::string error_msg = strings::StrCat( + const string error_msg = strings::StrCat( "[TpuCompilationDisabled]: Compilation cache miss, but compilation " "disabled, session_name(", session_name, ") Debug String: ", subgraph_key.debug_string); @@ -473,7 +403,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( } else { TpuCompilationMetrics::IncrementCacheLookupCount( /*is_cache_hit=*/true, session_name); - const std::string msg = + const string msg = strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key, "), session_name(", session_name, ")"); TRACESTRING(msg); @@ -536,8 +466,8 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( return entry->initialization_status; } -Status TpuCompilationCacheInterface::GetKeysFromUid( - int64 uid, std::vector* keys) { +Status TpuCompilationCacheInterface::GetKeysFromUid(int64 uid, + std::vector* keys) { keys->clear(); absl::MutexLock lock(&mu_); @@ -549,49 +479,5 @@ Status TpuCompilationCacheInterface::GetKeysFromUid( return Status::OK(); } -Status TpuCompilationCacheInterface::Lookup( - int64 uid, int proto_index, - std::unique_ptr* entry) { - entry->reset(); - - profiler::TraceMe proto_lookup_traceme( - "TPU compilation cache proto lookup by uid", - /*level=*/2); - - absl::MutexLock lock(&mu_); - const auto iter = entries_by_uid_.find(uid); - if (iter == entries_by_uid_.end()) { - return errors::NotFound("No subgraph found for uid ", uid); - } - CompiledSubgraph* cache_entry = iter->second; - if (proto_index < 0 || - proto_index >= cache_entry->tpu_program_group->program_count()) { - return errors::NotFound("No proto found for core index ", proto_index, - " in subgraph with uid ", uid); - } - *entry = absl::make_unique(this, cache_entry, - proto_index); - return Status::OK(); -} - -Status TpuCompilationCacheInterface::Lookup( - const std::string& proto_key, - std::unique_ptr* entry) { - entry->reset(); - - 
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup", - /*level=*/2); - - absl::MutexLock lock(&mu_); - const auto iter = entries_by_proto_key_.find(proto_key); - if (iter == entries_by_proto_key_.end()) { - return errors::NotFound("No proto found for key ", proto_key); - } - CompiledSubgraph* cache_entry = iter->second.first; - int proto_index = iter->second.second; - *entry = absl::make_unique(this, cache_entry, - proto_index); - return Status::OK(); -} } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h index 7b206fb1cf4..cde6467b7af 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/tpu/kernels/compiled_subgraph.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h" #include "tensorflow/core/tpu/kernels/trace_util.h" @@ -49,20 +48,18 @@ class CompilationRefHolder : public ResourceBase { ~CompilationRefHolder() override = default; }; -// Wrapper for a cache entry returned by all the TpuCompilationCacheInterface -// `Lookup` methods, and ensures the underlying proto is not garbage-collected -// until the client discards the ptr. +// Base class for a reference to a cached tpu program. A unique_ptr to a +// CompilationCacheEntryRef is returned by all the cache Lookup methods below, +// and ensures the underlying proto is not garbage-collected until the client +// discards the ptr. 
+template class CompilationCacheEntryRef { public: - CompilationCacheEntryRef(); - CompilationCacheEntryRef(TpuCompilationCacheInterface* parent, - CompiledSubgraph* entry, int index); + virtual ~CompilationCacheEntryRef() = default; - virtual ~CompilationCacheEntryRef(); - - // Returns a TpuCompilationCacheEntry that should not be used beyond the - // lifetime of the CompilationCacheEntryRef. - virtual TpuCompilationCacheEntry get(); + // Returns a CompilationCacheEntry that should not be used beyond the lifetime + // of the tpu::CompilationCacheEntryRef. + virtual CacheEntryType get() = 0; // Mutates this ref to point to the entry's subentry (for // sharding/unsharding) or main entry (unchanged) as specified by @@ -72,15 +69,7 @@ class CompilationCacheEntryRef { // // If the requested subentry does not exist, the ref will point to a nullptr // entry, and the original entry will be unref'ed. - virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target); - - protected: - TpuCompilationCacheInterface* parent_; // Not owned. - // A reference to entry_ is acquired in the constructor and released via - // parent->DiscardEntryRefs in the destructor. - CompiledSubgraph* entry_; - // The index of the program in entry_ that is returned by the get method. 
- int index_; + virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) = 0; }; class TpuCompilationCacheInterface : public ResourceBase { @@ -108,8 +97,7 @@ class TpuCompilationCacheInterface : public ResourceBase { const TpuCompilationCacheKey& subgraph_key, const SessionMetadata* session_metadata, CompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector* proto_key, - std::vector* may_modify_variables, + std::vector* proto_key, std::vector* may_modify_variables, absl::Span* hlo_metadatas, const std::function& compile_function); @@ -136,18 +124,19 @@ class TpuCompilationCacheInterface : public ResourceBase { // Looks up an executable corresponding to the model-parallel core index of // the subgraph represented by key. On success a pointer to an EntryRef // holding the program is returned in entry. - Status Lookup(const std::string& proto_key, - std::unique_ptr* entry); + template + Status Lookup(const string& proto_key, std::unique_ptr* entry); // Looks up an executable corresponding to the model-parallel core index of // the subgraph represented by uid. On success a pointer to an EntryRef // holding the program is returned in entry. + template Status Lookup(int64 uid, int proto_index, - std::unique_ptr* entry); + std::unique_ptr* entry); // Looks up the subgraph represented by uid, and returns the vector of keys, // one per core, corresponding to that subgraph. - Status GetKeysFromUid(int64 uid, std::vector* keys); + Status GetKeysFromUid(int64 uid, std::vector* keys); // Makes a reference holder for this cache, that can be stored in the per-step // resource manager and will ensure that compiled entries persist until the @@ -181,7 +170,7 @@ class TpuCompilationCacheInterface : public ResourceBase { // parent_->DiscardEntryRefs. void AddRef(CompiledSubgraph* entry); - std::string DebugString() const override; + string DebugString() const override; private: TpuCompilationCacheInterface* parent_; // Not owned. 
@@ -196,8 +185,7 @@ class TpuCompilationCacheInterface : public ResourceBase { const TpuCompilationCacheKey& subgraph_key, const SessionMetadata* session_metadata, CompilationRefHolder* per_step_ref_holder, int64* uid, - std::vector* proto_key, - std::vector* may_modify_variables, + std::vector* proto_key, std::vector* may_modify_variables, std::vector* removed_entries, absl::Span* hlo_metadatas, const std::function& compile_function); @@ -242,14 +230,14 @@ class TpuCompilationCacheInterface : public ResourceBase { ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Removes the entry with given key from cache. - size_t RemoveEntry(const std::string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Inserts the given key and entry to cache. - void InsertEntry(const std::string& key, CompiledSubgraph* entry) + void InsertEntry(const string& key, CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Returns the cache key matching given subgraph_key. - std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) + string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates a new entry by running initialize_programs and places it in the @@ -259,7 +247,7 @@ class TpuCompilationCacheInterface : public ResourceBase { // // **InitializeEntry releases mu_ during the call to initialize_programs.** virtual CompiledSubgraph* InitializeEntry( - const std::string& key, + const string& key, const std::function& initialize_programs, const TpuCompilationCacheKey& subgraph_key) @@ -288,16 +276,13 @@ class TpuCompilationCacheInterface : public ResourceBase { // cache_ key matching a given subgraph key. When doing a lookup, check // session_key_map_ first to avoid unnecessay fingerprint computation. // Map from key prefix + session_handle to a cache_ key. 
- absl::node_hash_map session_key_map_ - ABSL_GUARDED_BY(mu_); + absl::node_hash_map session_key_map_ ABSL_GUARDED_BY(mu_); // Map from key prefix + fingerprint to a cache_ key. - absl::node_hash_map fingerprint_key_map_ - ABSL_GUARDED_BY(mu_); + absl::node_hash_map fingerprint_key_map_ ABSL_GUARDED_BY(mu_); // All the subgraph entries that can be looked up in the cache. An entry is // marked for eviction iff it is present in cache_ and not in // entries_by_last_use_. - std::unordered_map cache_ - ABSL_GUARDED_BY(mu_); + std::unordered_map cache_ ABSL_GUARDED_BY(mu_); // All the subgraph entries that can be looked up in the cache, indexed by // uid. absl::node_hash_map entries_by_uid_ @@ -305,7 +290,7 @@ class TpuCompilationCacheInterface : public ResourceBase { // All the protos that can be looked up in the cache, indexed by proto // key. The value of the map is a subgraph and the index of the proto compiled // for that subgraph. - std::unordered_map> + std::unordered_map> entries_by_proto_key_ ABSL_GUARDED_BY(mu_); // Map from last_use to entry, used to mark entries for eviction in LRU // order. 
If an entry's last_use counter is not present as a key in @@ -319,6 +304,50 @@ class TpuCompilationCacheInterface : public ResourceBase { TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) = delete; }; + +template +Status TpuCompilationCacheInterface::Lookup( + int64 uid, int proto_index, std::unique_ptr* entry) { + entry->reset(); + + profiler::TraceMe proto_lookup_traceme( + "TPU compilation cache proto lookup by uid", + /*level=*/2); + + absl::MutexLock lock(&mu_); + const auto iter = entries_by_uid_.find(uid); + if (iter == entries_by_uid_.end()) { + return errors::NotFound("No subgraph found for uid ", uid); + } + CompiledSubgraph* cache_entry = iter->second; + if (proto_index < 0 || + proto_index >= cache_entry->tpu_program_group->program_count()) { + return errors::NotFound("No proto found for core index ", proto_index, + " in subgraph with uid ", uid); + } + *entry = absl::make_unique(this, cache_entry, proto_index); + return Status::OK(); +} + +template +Status TpuCompilationCacheInterface::Lookup( + const string& proto_key, std::unique_ptr* entry) { + entry->reset(); + + profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup", + /*level=*/2); + + absl::MutexLock lock(&mu_); + const auto iter = entries_by_proto_key_.find(proto_key); + if (iter == entries_by_proto_key_.end()) { + return errors::NotFound("No proto found for key ", proto_key); + } + CompiledSubgraph* cache_entry = iter->second.first; + int proto_index = iter->second.second; + *entry = absl::make_unique(this, cache_entry, proto_index); + return Status::OK(); +} + } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.cc index 29864a310d1..f30a503d2d2 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.cc @@ -16,50 +16,70 
@@ limitations under the License. namespace tensorflow { namespace tpu { +namespace { +class CompilationCacheFetchTargetUtility { + public: + CompilationCacheFetchTargetUtility() + : names_({"Invalid", "Main", "Sharding", "Unsharding"}) {} + + std::string name(CompilationCacheFetchTarget target) const { + return names_[static_cast(target)]; + } + + private: + const std::vector names_; +}; + +std::string GetName(CompilationCacheFetchTarget target) { + static const auto* util = new CompilationCacheFetchTargetUtility(); + return util->name(target); +} + +} // namespace TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup( TpuCompilationCacheInterface* cache) - : cache_(cache) { - cache_->Ref(); -} + : cache_(cache) {} TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() { cache_->Unref(); } Status TpuCompilationCacheLocalLookup::Lookup( - const string& proto_key, std::unique_ptr* entry, + const string& proto_key, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) { profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup", /*level=*/2); - Status s = cache_->Lookup(proto_key, entry); + Status s = cache_->Lookup( + proto_key, entry); VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status " << s; if (!s.ok()) { return s; } s = (*entry)->ToSubEntryRef(fetch_target); - VLOG(1) << "Fetched subentry: " - << CompilationCacheFetchTarget_Name(fetch_target) << " with status " + + VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status " << s; return s; } Status TpuCompilationCacheLocalLookup::Lookup( int64 uid, int proto_index, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) { profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid", /*level=*/2); - Status s = cache_->Lookup(uid, proto_index, entry); + Status s = cache_->Lookup( + uid, proto_index, entry); VLOG(1) << "Looked up uid " << uid << ", index " << proto_index << " 
in local subgraph cache status " << s; if (!s.ok()) { return s; } s = (*entry)->ToSubEntryRef(fetch_target); - VLOG(1) << "Fetched subentry: " - << CompilationCacheFetchTarget_Name(fetch_target) << " with status " + VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status " << s; return s; } @@ -67,5 +87,6 @@ Status TpuCompilationCacheLocalLookup::Lookup( string TpuCompilationCacheLocalLookup::DebugString() const { return "TpuCompilationCacheLocalLookup"; } + } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h index 8db4c11ebea..eb5aadcd3e2 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h @@ -28,17 +28,24 @@ namespace tpu { // Class for looking up TPU programs when the execute and compile Op are in the // same address space. The proto is simply looked up in the compilation cache, // without any serialization taking place. 
-class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup { +class TpuCompilationCacheLocalLookup + : public TpuCompilationCacheLookup< + CompilationCacheEntryRef> { public: + using TpuCompilationCacheEntryRef = + ::tensorflow::tpu::CompilationCacheEntryRef; + using EntryRefImpl = + ::tensorflow::tpu::TpuCompilationCacheExternal::EntryRefImpl; + explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache); ~TpuCompilationCacheLocalLookup() override; Status Lookup(const string& proto_key, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) override; Status Lookup(int64 uid, int proto_index, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) override; string DebugString() const override; diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h index ab476322a8a..0d1a53d31d2 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h @@ -23,11 +23,10 @@ limitations under the License. namespace tensorflow { namespace tpu { -// TODO(b/162241759): consider merging TpuCompilationCacheLookup and -// TpuCompilationCacheInterface. // Base class allowing Execute Ops to look up TPU programs. Different subclasses // are used when the execute Op is in the same address space as the compile Op, // and when they need to communicate over RPC. +template class TpuCompilationCacheLookup : public ResourceBase { public: ~TpuCompilationCacheLookup() override = default; @@ -44,11 +43,12 @@ class TpuCompilationCacheLookup : public ResourceBase { // fetch_target requests one of them, then after this call // (*entry)->get().get_executable() will return nullptr. 
virtual Status Lookup(const string& proto_key, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) = 0; - virtual Status Lookup(const string& proto_key, - std::unique_ptr* entry) { + virtual Status Lookup( + const string& proto_key, + std::unique_ptr* entry) { return Lookup(proto_key, std::move(entry), CompilationCacheFetchTarget::MAIN); } @@ -58,15 +58,17 @@ class TpuCompilationCacheLookup : public ResourceBase { // returned in program. The wrapper is guaranteed to be valid only during the // execution of the Op requesting the proto. virtual Status Lookup(int64 uid, int proto_index, - std::unique_ptr* entry, + std::unique_ptr* entry, CompilationCacheFetchTarget fetch_target) = 0; - virtual Status Lookup(int64 uid, int proto_index, - std::unique_ptr* entry) { + virtual Status Lookup( + int64 uid, int proto_index, + std::unique_ptr* entry) { return Lookup(uid, proto_index, std::move(entry), CompilationCacheFetchTarget::MAIN); } }; + } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index ce18e844e66..4ed646af302 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -413,6 +413,46 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo( return Status::OK(); } +/* static */ +Status TpuCompileOpKernelCommon::ComputeArgumentShapes( + const tpu::TPUCompileMetadataProto& metadata, + const std::vector& dynamic_shapes, + std::vector* arg_shapes) { + arg_shapes->resize(metadata.args_size()); + int dynamic_shape_pos = 0; + for (int i = 0; i < metadata.args_size(); ++i) { + const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i); + // The XLA compiler determines the shape of each constant by inspecting the + // value of its corresponding host-memory tensor. 
As a result, we don't need + // to give the compiler graph-inferred shapes for constant arguments. + if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) { + continue; + } + TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape())); + PartialTensorShape static_shape(arg.shape()); + + TensorShape& shape = (*arg_shapes)[i]; + if (static_shape.IsFullyDefined()) { + TF_RET_CHECK(static_shape.AsTensorShape(&shape)); + } else { + TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size()) + << "Too few dynamic shapes"; + shape = dynamic_shapes[dynamic_shape_pos++]; + if (!static_shape.IsCompatibleWith(shape)) { + return errors::InvalidArgument( + "Mismatch between static and dynamic shape for argument. Static " + "shape: ", + static_shape.DebugString(), + "; dynamic shape: ", shape.DebugString()); + } + } + } + // Checks we consumed all of the dynamic shapes. + TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size()) + << "Too many dynamic shapes"; + return Status::OK(); +} + // Function arguments and return values lose their device assignments, so we // must recreate them. /* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals( diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.h b/tensorflow/core/tpu/kernels/tpu_compile_op_common.h index 327aa460ddd..3d3f0afcdb7 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.h @@ -99,6 +99,15 @@ class TpuCompileOpKernelCommon { const std::vector& arg_shapes, TpuProgramGroupInterface* tpu_program_group) = 0; + // Computes shapes for each argument. Uses both the static shape from the + // metadata, and the dynamic shapes where the static shape is not + // defined. There must be one dynamic_shape for each argument with a + // partially defined shape, in index order. 
+ static Status ComputeArgumentShapes( + const tpu::TPUCompileMetadataProto& metadata, + const std::vector& dynamic_shapes, + std::vector* arg_shapes); + // Performs shape inference on `computation`, filling shape_info with operator // shapes. The shapes of the _Arg nodes are taken from `arg_shapes`. static Status RunShapeInferenceOnComputation( diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc index 3440b6d265a..5cc35a07e66 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc @@ -540,43 +540,5 @@ Status CompileOpMetadataFromContext(OpKernelConstruction* ctx, } return Status::OK(); } - -Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata, - const std::vector& dynamic_shapes, - std::vector* arg_shapes) { - arg_shapes->resize(metadata.args_size()); - int dynamic_shape_pos = 0; - for (int i = 0; i < metadata.args_size(); ++i) { - const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i); - // The XLA compiler determines the shape of each constant by inspecting the - // value of its corresponding host-memory tensor. As a result, we don't need - // to give the compiler graph-inferred shapes for constant arguments. - if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) { - continue; - } - TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape())); - PartialTensorShape static_shape(arg.shape()); - - TensorShape& shape = (*arg_shapes)[i]; - if (static_shape.IsFullyDefined()) { - TF_RET_CHECK(static_shape.AsTensorShape(&shape)); - } else { - TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size()) - << "Too few dynamic shapes"; - shape = dynamic_shapes[dynamic_shape_pos++]; - if (!static_shape.IsCompatibleWith(shape)) { - return errors::InvalidArgument( - "Mismatch between static and dynamic shape for argument. 
Static " - "shape: ", - static_shape.DebugString(), - "; dynamic shape: ", shape.DebugString()); - } - } - } - // Checks we consumed all of the dynamic shapes. - TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size()) - << "Too many dynamic shapes"; - return Status::OK(); -} } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h index ea13d33b521..bc60f64286a 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -159,14 +159,6 @@ se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx, TPUCompileMetadataProto* metadata, NameAttrList* function_name, std::string* mlir_module); - -// Computes shapes for each argument. Uses both the static shape from the -// metadata, and the dynamic shapes where the static shape is not -// defined. There must be one dynamic_shape for each argument with a -// partially defined shape, in index order. -Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata, - const std::vector& dynamic_shapes, - std::vector* arg_shapes); } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc index 5a8c283c7c2..e098dbd682c 100644 --- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc @@ -25,8 +25,6 @@ limitations under the License. 
#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" #include "tensorflow/core/tpu/kernels/tpu_op_consts.h" #include "tensorflow/core/tpu/tpu_api.h" @@ -255,10 +253,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { mesh_state_interface)); } - VLOG(1) << "Removing existing proto compilation cache lookup if it exists"; - OP_REQUIRES_OK(ctx, DeleteIfExists( - rmgr, tpu::kCompiledProtoCacheResourceName)); - if (enable_whole_mesh_compilations_) { // If this is a whole mesh compilation mode, create the compilation cache, // if missing. @@ -282,13 +276,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { if (local_compilation_cache != nullptr) { local_compilation_cache->Unref(); - - tpu::TpuCompilationCacheLookup* proto_lookup; - proto_lookup = - new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache); - OP_REQUIRES_OK( - ctx, rmgr->Create(rmgr->default_container(), - tpu::kCompiledProtoCacheResourceName, proto_lookup)); } Tensor* ctx_output; diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc index 3522ace379a..51c9dd481a3 100644 --- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc +++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc @@ -40,12 +40,10 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" @@ -58,10 +56,14 @@ limitations under the License. #include "tensorflow/stream_executor/tpu/tpu_node_context.h" namespace tensorflow { + namespace { -using ::tensorflow::tpu::CompilationCacheEntryRef; -using ::tensorflow::tpu::TpuCompilationCacheLookup; + using ::tensorflow::tpu::TpuNodeContext; +using CompilationCacheEntryRef = ::tensorflow::tpu::CompilationCacheEntryRef< + ::tensorflow::tpu::TpuCompilationCacheEntry>; +using TpuCompilationCacheLookup = + ::tensorflow::tpu::TpuCompilationCacheLookup; // Looks up the input `key` in the compilation cache, populating // `*rendezvous_key_base` and `*entry`. @@ -639,35 +641,28 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) { profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2); string rendezvous_key_base; - std::unique_ptr entry_ref; + std::unique_ptr entry; TF_RETURN_IF_ERROR( - GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref)); + GetComputationCacheEntry(context, &rendezvous_key_base, &entry)); // Shapes of the inputs and outputs, in xla::Shape form. 
- tpu::TpuCompilationCacheEntry entry = entry_ref->get(); - const tpu::TpuProgramGroup* tpu_program_group = - tensorflow::down_cast( - entry.tpu_program_group()); - CHECK_NE(tpu_program_group, nullptr); - const int core_index = entry.core_index(); - const TPUExecutableInfoProto& executable = - tpu_program_group->executable_info(core_index); + const TPUExecutableInfoProto* proto = entry->get().get_executable_info(); xla::Backend* const backend = node_context->backend(); xla::TransferManager* const transfer_manager = backend->transfer_manager(); TF_RET_CHECK(context->op_device_context()); se::Stream* stream = context->op_device_context()->stream(); - TF_RET_CHECK(executable.input_shapes_size() == 1); + TF_RET_CHECK(proto->input_shapes_size() == 1); - xla::Shape host_shape(executable.input_shapes(0)); + xla::Shape host_shape(proto->input_shapes(0)); TF_ASSIGN_OR_RETURN( auto variable_update_map, - BuildVariableUpdateMap(executable.variable_indices(), + BuildVariableUpdateMap(proto->variable_indices(), fused_device_var_reads_in_computation_inputs_, fused_device_var_updates_in_computation_outputs_, - executable.output_tensor_shapes().size())); + proto->output_tensor_shapes().size())); TF_ASSIGN_OR_RETURN( std::unique_ptr input_buffers, BuildComputationInputs(context, host_shape, variable_update_map, backend, @@ -702,9 +697,8 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) { // Snapshot the inputs, if a snapshot was requested. 
std::shared_ptr hlo_snapshot; - if (executable.has_session_module()) { - hlo_snapshot = - std::make_shared(executable.session_module()); + if (proto->has_session_module()) { + hlo_snapshot = std::make_shared(proto->session_module()); auto literal = std::make_shared(shaped_buffer.on_host_shape()); transfer_manager->TransferLiteralFromDevice( @@ -729,9 +723,9 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) { const uint32 rng_seed = GetXLARandomSeed(); std::unique_ptr device_assignment; - if (executable.has_device_assignment()) { + if (proto->has_device_assignment()) { TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize( - executable.device_assignment())); + proto->device_assignment())); } VLOG(4) << "Input buffers after alias resolution: " @@ -749,24 +743,24 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) { // we free a memory and reassign it to other users while a program is running, // all subsequent writes to the program that could possibly clobber the memory // will depend on the program to finish. 
- const TPUHostTransferInfoProto& host_transfer_info = - tpu_program_group->host_transfer_info(core_index); + const TPUHostTransferInfoProto* host_transfer_info = + entry->get().get_host_transfer_info(); + const xla::HloProto* hlo_metadata = entry->get().get_hlo_metadata(); TF_ASSIGN_OR_RETURN( xla::ExecutionOutput output, - TPUExecute(executable, host_transfer_info, - *tpu_program_group->hlo_metadata(core_index), std::move(input), + TPUExecute(*proto, *host_transfer_info, *hlo_metadata, std::move(input), rendezvous_key_base, rng_seed, node_context.get(), device_assignment.get(), context->cancellation_manager(), context, stream, transfer_stream_ptr.get(), - tpu_program_group->tpu_program(core_index))); + entry->get().get_tpu_program())); stream->ThenRecordEvent(definition_event.get()); TF_ASSIGN_OR_RETURN( std::unique_ptr output_buffers, - AllocateOutputTensors( - context, output.ConsumeResult(), executable.output_tensor_shapes(), - variable_update_map, node_context.get(), stream, device_ordinal, - input_buffers.get(), definition_event)); + AllocateOutputTensors(context, output.ConsumeResult(), + proto->output_tensor_shapes(), variable_update_map, + node_context.get(), stream, device_ordinal, + input_buffers.get(), definition_event)); // Transfer the outputs and save the snapshot to disk. if (hlo_snapshot) { diff --git a/tensorflow/core/tpu/kernels/tpu_program_c_api.h b/tensorflow/core/tpu/kernels/tpu_program_c_api.h index 41c7d47cf97..c9951e4d5ce 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_c_api.h +++ b/tensorflow/core/tpu/kernels/tpu_program_c_api.h @@ -21,9 +21,6 @@ limitations under the License. typedef struct XLA_TpuProgram XLA_TpuProgram; -// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj. -enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding }; - extern "C" { // Creates a new TPU program. 
@@ -67,15 +64,6 @@ TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata( TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables( const XLA_TpuProgram* tpu_program, bool* may_modify_variables); -// Check if TPU program has sharding. -TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( - const XLA_TpuProgram* tpu_program); - -// Gets TPU program by sharding type. Return value is valid only when the -// `status.status()` returns `OK`. -TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram( - XLA_TpuProgram* tpu_program, TpuProgramShardingType type); - struct TfTpu_TpuProgramApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free); @@ -88,8 +76,6 @@ struct TfTpu_TpuProgramApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); }; } // extern "C" diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc index 39d1f38b104..e22175af270 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.cc +++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/stream_executor/tpu/proto_helper.h" #include "tensorflow/stream_executor/tpu/status_helper.h" @@ -99,71 +98,55 @@ StatusOr> CompileAheadOfTime( compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes, per_core_variable_indices, device_assignment); } -} // namespace -void TpuProgramGroup::Initialize( - absl::Span xla_tpu_programs) { +Status CreateTpuProgramGroup( + absl::Span xla_tpu_programs, + TpuProgramGroupInterface* tpu_program_group_interface) { CHECK_GT(xla_tpu_programs.size(), 0); - set_tpu_programs(xla_tpu_programs); + TpuProgramGroup* tpu_program_group = + tensorflow::down_cast(tpu_program_group_interface); + CHECK_NE(tpu_program_group, nullptr); + tpu_program_group->set_tpu_programs(xla_tpu_programs); - std::vector may_modify_variables_array(xla_tpu_programs.size(), false); - std::vector executable_infos(xla_tpu_programs.size()); - std::vector host_transfer_infos( - xla_tpu_programs.size()); - std::vector hlo_metadatas(xla_tpu_programs.size()); - for (size_t i = 0; i < xla_tpu_programs.size(); ++i) { - const XLA_TpuProgram* xla_tpu_program = xla_tpu_programs[i]; - bool may_modify_variables; - TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn( - xla_tpu_program, &may_modify_variables); - may_modify_variables_array[i] = may_modify_variables; + // TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1. 
+ bool may_modify_variables; + TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_programs[0], + &may_modify_variables); + tpu_program_group->set_may_modify_variables( + std::vector(1, may_modify_variables)); - TpuSerializedProto serialized_executable_info; - TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn( - xla_tpu_program, &serialized_executable_info); - TPUExecutableInfoProto executable_info = - se_tpu::DeserializeProto( - serialized_executable_info); - executable_infos[i] = executable_info; - StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info); + TpuSerializedProto serialized_executable_info; + TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn( + xla_tpu_programs[0], &serialized_executable_info); + TPUExecutableInfoProto executable_info = + se_tpu::DeserializeProto( + serialized_executable_info); + tpu_program_group->set_executable_info(executable_info); + StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info); - TPUHostTransferInfoProto host_transfer_info; - TpuSerializedProto serialized_host_transfer_info; - TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn( - xla_tpu_program, &serialized_host_transfer_info); - if (serialized_host_transfer_info.size > 0) { - host_transfer_info = se_tpu::DeserializeProto( - serialized_host_transfer_info); - StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info); - } - host_transfer_infos[i] = host_transfer_info; - - TpuSerializedProto serialized_hlo_metadata; - TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_program, - &serialized_hlo_metadata); - xla::HloProto hlo_metadata = - se_tpu::DeserializeProto(serialized_hlo_metadata); - hlo_metadatas[i] = hlo_metadata; - StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata); + TPUHostTransferInfoProto host_transfer_info; + TpuSerializedProto serialized_host_transfer_info; + TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn( + xla_tpu_programs[0], &serialized_host_transfer_info); + if 
(serialized_host_transfer_info.size > 0) { + host_transfer_info = se_tpu::DeserializeProto( + serialized_host_transfer_info); + StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info); } + tpu_program_group->set_host_transfer_info(host_transfer_info); - may_modify_variables_ = may_modify_variables_array; - executable_infos_ = executable_infos; - host_transfer_infos_ = host_transfer_infos; - hlo_metadatas_ = hlo_metadatas; - RefreshHloMetadatasPtrs(); + TpuSerializedProto serialized_hlo_metadata; + TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(xla_tpu_programs[0], + &serialized_hlo_metadata); + xla::HloProto hlo_metadata = + se_tpu::DeserializeProto(serialized_hlo_metadata); + tpu_program_group->set_hlo_metadata(hlo_metadata); + StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata); + + return Status::OK(); } -bool TpuProgramGroup::has_sharding_program() const { - for (const XLA_TpuProgram* tpu_program : tpu_programs_) { - if (!TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_program)) { - return false; - } - } - return true; -} - -size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); } +} // namespace int64_t TpuProgramGroup::program_size() const { int64_t total_size = 0; @@ -218,6 +201,12 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() { TF_RET_CHECK(per_core_output_shapes.size() == per_core_variable_indices.size()); + // TODO(henrytan): add an interface to TpuProgramGroupInterface to set + // may_modify_variables. + TpuProgramGroup* tpu_program_group = + tensorflow::down_cast(tpu_program_group_interface); + tpu_program_group->may_modify_variables_ = may_modify_variables; + // With shardable input/output pairs, XLA could generate separate // sharding/unsharding programs along with the main program. 
The // sharding/unsharding programs will be in nested entries of the AOT @@ -232,20 +221,17 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() { TF_RET_CHECK(xla_tpu_programs.size() == 1 || xla_tpu_programs.size() == metadata.num_cores_per_replica()); - // TODO(henrytan): add an interface to TpuProgramGroupInterface to set - // may_modify_variables. - TpuProgramGroup* tpu_program_group = - tensorflow::down_cast(tpu_program_group_interface); - tpu_program_group->Initialize(xla_tpu_programs); - tpu_program_group->may_modify_variables_ = may_modify_variables; + TF_RETURN_IF_ERROR( + CreateTpuProgramGroup(xla_tpu_programs, tpu_program_group)); return Status::OK(); } TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other) : may_modify_variables_(std::move(other.may_modify_variables_)), + host_compute_metadata_(std::move(other.host_compute_metadata_)), tpu_programs_(std::move(other.tpu_programs_)), - executable_infos_(std::move(other.executable_infos_)), - host_transfer_infos_(std::move(other.host_transfer_infos_)), + executable_info_(std::move(other.executable_info_)), + host_transfer_info_(std::move(other.host_transfer_info_)), hlo_metadatas_(std::move(other.hlo_metadatas_)) { RefreshHloMetadatasPtrs(); } @@ -262,12 +248,6 @@ absl::Span TpuProgramGroup::hlo_metadatas() const { return hlo_metadatas_ptrs_; } -const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const { - CHECK_GE(index, 0); - CHECK_LT(index, hlo_metadatas_ptrs_.size()); - return hlo_metadatas_ptrs_[index]; -} - void TpuProgramGroup::RefreshHloMetadatasPtrs() { hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size()); for (const auto& hlo_metadata_internal_ : hlo_metadatas_) { @@ -282,47 +262,6 @@ Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key, return Status::OK(); } -const std::vector& TpuProgramGroup::may_modify_variables() const { - return may_modify_variables_; -} - -void TpuProgramGroup::set_may_modify_variables( - const std::vector& may_modify_variables) { - 
may_modify_variables_ = may_modify_variables; -} - -const std::vector& TpuProgramGroup::tpu_programs() const { - return tpu_programs_; -} - -const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const { - CHECK_GE(index, 0); - CHECK_LT(index, tpu_programs_.size()); - return tpu_programs_[index]; -} - -void TpuProgramGroup::set_tpu_programs( - absl::Span tpu_programs) { - tpu_programs_.resize(tpu_programs.size()); - for (size_t i = 0; i < tpu_programs.size(); ++i) { - tpu_programs_[i] = tpu_programs[i]; - } -} - -const TPUExecutableInfoProto& TpuProgramGroup::executable_info( - int index) const { - CHECK_GE(index, 0); - CHECK_LT(index, executable_infos_.size()); - return executable_infos_[index]; -} - -const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info( - int index) const { - CHECK_GE(index, 0); - CHECK_LT(index, host_transfer_infos_.size()); - return host_transfer_infos_[index]; -} - /*static*/ Status TpuProgramGroup::CompileAndBuild( const TpuCompilationRequestProto& compilation_request, @@ -348,27 +287,15 @@ Status TpuProgramGroup::CompileAndBuild( TF_RET_CHECK(count == 1 || count == compilation_request.metadata().num_cores_per_replica()); - VLOG(1) << "Initialize TpuProgramGroup."; - TpuProgramGroup* tpu_program_group = - tensorflow::down_cast(tpu_program_group_interface); - tpu_program_group->Initialize( - absl::MakeConstSpan(&xla_tpu_programs[0], count)); + VLOG(1) << "CreateTpuProgramGroup"; + Status serialize_status = + CreateTpuProgramGroup(absl::MakeConstSpan(&xla_tpu_programs[0], count), + tpu_program_group_interface); + VLOG(1) << absl::StrCat("Run CreateTpuProgramGroup completed. 
StatusCode: ", + serialize_status.code()); TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs); - return status.status(); + return serialize_status; } -std::vector TpuProgramGroup::tpu_programs( - TpuProgramShardingType sharding_type) const { - std::vector tpu_programs; - tpu_programs.reserve(tpu_programs_.size()); - for (size_t i = 0; i < tpu_programs_.size(); ++i) { - if (TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) { - tpu_programs.push_back(TpuProgramApiFn()->TpuProgram_GetTpuProgramFn( - tpu_programs_[i], sharding_type)); - CHECK_NE(tpu_programs[i], nullptr); - } - } - return tpu_programs; -} } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h index b76ef3d507a..4bc8cdd003a 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.h +++ b/tensorflow/core/tpu/kernels/tpu_program_group.h @@ -102,16 +102,11 @@ class TpuProgramGroup : public TpuProgramGroupInterface { const absl::optional& xla_device_assignment, TpuProgramGroupInterface* tpu_program_group_interface); - // Initializes `TpuProgramGroup` object with `xla_tpu_programs`. 
- void Initialize(absl::Span xla_tpu_programs); - TpuProgramGroup() = default; TpuProgramGroup(TpuProgramGroup&& other); TpuProgramGroup& operator=(TpuProgramGroup&&) = delete; - bool has_sharding_program() const override; - - size_t program_count() const override; + size_t program_count() const override { return tpu_programs_.size(); } int64_t program_size() const override; @@ -122,29 +117,58 @@ class TpuProgramGroup : public TpuProgramGroupInterface { Status LogCompilationStats(const TpuCompilationCacheKey& key, absl::Duration duration) override; - const std::vector& may_modify_variables() const override; - void set_may_modify_variables(const std::vector& may_modify_variables); + const std::vector& may_modify_variables() const override { + return may_modify_variables_; + } + void set_may_modify_variables(const std::vector& may_modify_variables) { + may_modify_variables_ = may_modify_variables; + } - const std::vector& tpu_programs() const; - std::vector tpu_programs(TpuProgramShardingType type) const; - const XLA_TpuProgram* tpu_program(int index) const; - void set_tpu_programs(absl::Span tpu_programs); + const tf2xla::HostComputeMetadata& host_compute_metadata() const { + return host_compute_metadata_; + } + void set_host_compute_metadata( + const tf2xla::HostComputeMetadata& host_compute_metadata) { + host_compute_metadata_ = host_compute_metadata; + } - const TPUExecutableInfoProto& executable_info(int index) const; + const std::vector& tpu_programs() const { + return tpu_programs_; + } + void set_tpu_programs(absl::Span tpu_programs) { + tpu_programs_.resize(tpu_programs.size()); + for (size_t i = 0; i < tpu_programs.size(); ++i) { + tpu_programs_[i] = tpu_programs[i]; + } + } + + const TPUExecutableInfoProto& executable_info() const { + return executable_info_; + } + void set_executable_info(const TPUExecutableInfoProto& executable_info) { + executable_info_ = executable_info; + } + + const TPUHostTransferInfoProto& host_transfer_info() const { + return 
host_transfer_info_; + } + void set_host_transfer_info( + const TPUHostTransferInfoProto& host_transfer_info) { + host_transfer_info_ = host_transfer_info; + } - const TPUHostTransferInfoProto& host_transfer_info(int index) const; void set_hlo_metadata(const xla::HloProto& hlo_metadata); - const xla::HloProto* hlo_metadata(int index) const; absl::Span hlo_metadatas() const override; private: void RefreshHloMetadatasPtrs(); std::vector may_modify_variables_; + tf2xla::HostComputeMetadata host_compute_metadata_; std::vector tpu_programs_; // Not owned. - std::vector executable_infos_; - std::vector host_transfer_infos_; + TPUExecutableInfoProto executable_info_; + TPUHostTransferInfoProto host_transfer_info_; // To be consistent with the TpuProgramGroupInterface::hlo_metadatas() // signature, we store HloProto values in hlo_metadatas_ when diff --git a/tensorflow/core/tpu/kernels/tpu_program_group_interface.h b/tensorflow/core/tpu/kernels/tpu_program_group_interface.h index 4af94f8e1ad..cb7347783b1 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group_interface.h +++ b/tensorflow/core/tpu/kernels/tpu_program_group_interface.h @@ -20,8 +20,6 @@ limitations under the License. #include #include -#include "absl/time/time.h" -#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -36,16 +34,13 @@ class TpuProgramGroupInterface { public: virtual ~TpuProgramGroupInterface() = default; - // Check if whether sharding/unsharding program exists. - virtual bool has_sharding_program() const = 0; - // Computes program count. virtual size_t program_count() const = 0; // Computes total program size. virtual int64_t program_size() const = 0; - // Unloads and destroys safely TPU programs. + // Unloads and destroys safely Tpu programs. virtual void UnloadAndDestroyPrograms() = 0; // Logs program memory summary. 
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc index 6914a8cd102..682cc8b1c13 100644 --- a/tensorflow/core/tpu/tpu_library_init_fns.inc +++ b/tensorflow/core/tpu/tpu_library_init_fns.inc @@ -64,8 +64,6 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) { TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHostTransferInfo); TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata); TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram); return tensorflow::Status::OK(); }