Refactor TpuCompilationCacheEntry interface to return TpuProgramGroupInterface and core_index and makes CacheEntry less transparent and move application specific logics outside of cache.
PiperOrigin-RevId: 323651431 Change-Id: Ia790cf3bc5b17fe9647ac93b960357cf48868efd
This commit is contained in:
parent
916e0023b9
commit
711e05bd78
@ -209,19 +209,14 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_entry",
|
||||
srcs = ["tpu_compilation_cache_entry.cc"],
|
||||
hdrs = [
|
||||
"tpu_compilation_cache_entry.h",
|
||||
],
|
||||
deps = [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
":tpu_executable_info_proto_cc",
|
||||
":tpu_program_group",
|
||||
":tpu_program_group_interface",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/lib/core:refcount",
|
||||
"//tensorflow/core/platform:casts",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -1,54 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
TpuCompilationCacheEntry::TpuCompilationCacheEntry(
|
||||
const TpuProgramGroupInterface* tpu_program_group, int core_index)
|
||||
: tpu_program_group_(
|
||||
tensorflow::down_cast<const TpuProgramGroup*>(tpu_program_group)),
|
||||
core_index_(core_index) {}
|
||||
|
||||
// Constructor for an empty entry.
|
||||
TpuCompilationCacheEntry::TpuCompilationCacheEntry()
|
||||
: tpu_program_group_(nullptr) {}
|
||||
|
||||
const TPUExecutableInfoProto* TpuCompilationCacheEntry::get_executable_info()
|
||||
const {
|
||||
return &(tpu_program_group_->executable_info());
|
||||
}
|
||||
|
||||
const TPUHostTransferInfoProto*
|
||||
TpuCompilationCacheEntry::get_host_transfer_info() const {
|
||||
return &(tpu_program_group_->host_transfer_info());
|
||||
}
|
||||
|
||||
const xla::HloProto* TpuCompilationCacheEntry::get_hlo_metadata() const {
|
||||
return tpu_program_group_->hlo_metadatas()[core_index_];
|
||||
}
|
||||
|
||||
// TODO(henrytan,jiawenhao): When should we expect more than one
|
||||
// XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
|
||||
const XLA_TpuProgram* TpuCompilationCacheEntry::get_tpu_program() const {
|
||||
CHECK_EQ(tpu_program_group_->program_count(), 1);
|
||||
return tpu_program_group_->tpu_programs()[core_index_];
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
@ -18,30 +18,32 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// A version of `CompilationCacheEntry` to access Tpu binary program
|
||||
// `XLA_TpuProgram`.
|
||||
// Cache entry to hold a `TpuProgramGroupInterface` object that can be used to
|
||||
// fetch a TPU program for a given TPU core index.
|
||||
class TpuCompilationCacheEntry {
|
||||
public:
|
||||
explicit TpuCompilationCacheEntry(
|
||||
const TpuProgramGroupInterface* tpu_program_group, int core_index);
|
||||
const TpuProgramGroupInterface* tpu_program_group, int core_index)
|
||||
: tpu_program_group_(tpu_program_group), core_index_(core_index) {}
|
||||
|
||||
// Constructor for an empty entry.
|
||||
TpuCompilationCacheEntry();
|
||||
const TPUExecutableInfoProto* get_executable_info() const;
|
||||
const TPUHostTransferInfoProto* get_host_transfer_info() const;
|
||||
const xla::HloProto* get_hlo_metadata() const;
|
||||
// TODO(henrytan): maybe nicer to return C++ wrapper of `XLA_TpuProgram`
|
||||
const XLA_TpuProgram* get_tpu_program() const;
|
||||
TpuCompilationCacheEntry() : tpu_program_group_(nullptr), core_index_(-1) {}
|
||||
|
||||
const TpuProgramGroupInterface* tpu_program_group() const {
|
||||
return tpu_program_group_;
|
||||
}
|
||||
|
||||
int core_index() const { return core_index_; }
|
||||
|
||||
private:
|
||||
const TpuProgramGroup* tpu_program_group_;
|
||||
const TpuProgramGroupInterface* tpu_program_group_;
|
||||
int core_index_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user