Refactor the TpuCompilationCacheEntry interface to return TpuProgramGroupInterface and core_index, making the cache entry less transparent and moving application-specific logic out of the cache.

PiperOrigin-RevId: 323651431
Change-Id: Ia790cf3bc5b17fe9647ac93b960357cf48868efd
Author: Henry Tan, 2020-07-28 14:16:41 -07:00; committed by TensorFlower Gardener
parent 916e0023b9
commit 711e05bd78
3 changed files with 15 additions and 72 deletions
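In short, call sites used to read compiled artifacts straight off the cache entry; after this change the entry hands back only the TpuProgramGroupInterface pointer and a core index, and any interpretation happens outside the cache. A schematic of the call-site change (a sketch: UseEntry is a hypothetical caller, and the signatures are abridged from the diffs below):

#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"

namespace tensorflow {
namespace tpu {

// Hypothetical call site illustrating the narrowed entry interface.
void UseEntry(const TpuCompilationCacheEntry& entry) {
  // Before: the entry decoded the program group on the caller's behalf, e.g.
  //   const XLA_TpuProgram* program = entry.get_tpu_program();
  // After: the entry exposes only the group and the core index.
  const TpuProgramGroupInterface* group = entry.tpu_program_group();
  const int core_index = entry.core_index();
  (void)group;       // Application-specific logic consumes these...
  (void)core_index;  // ...outside of the cache.
}

}  // namespace tpu
}  // namespace tensorflow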

tensorflow/core/tpu/kernels/BUILD

@@ -209,19 +209,14 @@ cc_library(
 cc_library(
     name = "tpu_compilation_cache_entry",
-    srcs = ["tpu_compilation_cache_entry.cc"],
     hdrs = [
         "tpu_compilation_cache_entry.h",
     ],
     deps = [
-        ":compiled_subgraph",
-        ":tpu_compilation_cache_proto_cc",
         ":tpu_executable_info_proto_cc",
-        ":tpu_program_group",
         ":tpu_program_group_interface",
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/core:framework",
         "//tensorflow/core/lib/core:refcount",
-        "//tensorflow/core/platform:casts",
     ],
 )

tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.cc (deleted)

@@ -1,54 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
-
-#include "tensorflow/core/platform/casts.h"
-
-namespace tensorflow {
-namespace tpu {
-
-TpuCompilationCacheEntry::TpuCompilationCacheEntry(
-    const TpuProgramGroupInterface* tpu_program_group, int core_index)
-    : tpu_program_group_(
-          tensorflow::down_cast<const TpuProgramGroup*>(tpu_program_group)),
-      core_index_(core_index) {}
-
-// Constructor for an empty entry.
-TpuCompilationCacheEntry::TpuCompilationCacheEntry()
-    : tpu_program_group_(nullptr) {}
-
-const TPUExecutableInfoProto* TpuCompilationCacheEntry::get_executable_info()
-    const {
-  return &(tpu_program_group_->executable_info());
-}
-
-const TPUHostTransferInfoProto*
-TpuCompilationCacheEntry::get_host_transfer_info() const {
-  return &(tpu_program_group_->host_transfer_info());
-}
-
-const xla::HloProto* TpuCompilationCacheEntry::get_hlo_metadata() const {
-  return tpu_program_group_->hlo_metadatas()[core_index_];
-}
-
-// TODO(henrytan,jiawenhao): When should we expect more than one
-// XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
-const XLA_TpuProgram* TpuCompilationCacheEntry::get_tpu_program() const {
-  CHECK_EQ(tpu_program_group_->program_count(), 1);
-  return tpu_program_group_->tpu_programs()[core_index_];
-}
-
-}  // namespace tpu
-}  // namespace tensorflow
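With this file deleted, the downcasting logic above no longer lives in the cache. A caller that still needs the raw XLA_TpuProgram is expected to do the equivalent work itself; a minimal sketch of that caller-side logic (the free function GetTpuProgram is hypothetical, mirroring the removed accessor):

#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"

namespace tensorflow {
namespace tpu {

// Hypothetical free function: the logic of the removed get_tpu_program(),
// now living with the application rather than inside the cache entry.
const XLA_TpuProgram* GetTpuProgram(const TpuCompilationCacheEntry& entry) {
  const auto* group = tensorflow::down_cast<const TpuProgramGroup*>(
      entry.tpu_program_group());
  CHECK_EQ(group->program_count(), 1);
  return group->tpu_programs()[entry.core_index()];
}

}  // namespace tpu
}  // namespace tensorflow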

tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h

@@ -18,30 +18,32 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
-#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
 
 namespace tensorflow {
 namespace tpu {
 
-// A version of `CompilationCacheEntry` to access Tpu binary program
-// `XLA_TpuProgram`.
+// Cache entry to hold a `TpuProgramGroupInterface` object that can be used to
+// fetch a TPU program for a given TPU core index.
 class TpuCompilationCacheEntry {
  public:
   explicit TpuCompilationCacheEntry(
-      const TpuProgramGroupInterface* tpu_program_group, int core_index);
+      const TpuProgramGroupInterface* tpu_program_group, int core_index)
+      : tpu_program_group_(tpu_program_group), core_index_(core_index) {}
   // Constructor for an empty entry.
-  TpuCompilationCacheEntry();
-  const TPUExecutableInfoProto* get_executable_info() const;
-  const TPUHostTransferInfoProto* get_host_transfer_info() const;
-  const xla::HloProto* get_hlo_metadata() const;
-  // TODO(henrytan): maybe nicer to return C++ wrapper of `XLA_TpuProgram`
-  const XLA_TpuProgram* get_tpu_program() const;
+  TpuCompilationCacheEntry() : tpu_program_group_(nullptr), core_index_(-1) {}
+  const TpuProgramGroupInterface* tpu_program_group() const {
+    return tpu_program_group_;
+  }
+  int core_index() const { return core_index_; }
 
  private:
-  const TpuProgramGroup* tpu_program_group_;
+  const TpuProgramGroupInterface* tpu_program_group_;
   int core_index_;
 };
 
 }  // namespace tpu
 }  // namespace tensorflow
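With the header reduced to a plain holder, application-specific lookups move to call sites, which also handle the empty entry themselves. A hedged sketch of one such lookup (GetHloMetadata is illustrative, not part of the commit; it assumes the concrete group is a TpuProgramGroup, as the old accessor did):

#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"

namespace tensorflow {
namespace tpu {

// Illustrative caller-side lookup of per-core HLO metadata, formerly
// TpuCompilationCacheEntry::get_hlo_metadata(); returns nullptr for the
// empty entry.
const xla::HloProto* GetHloMetadata(const TpuCompilationCacheEntry& entry) {
  if (entry.tpu_program_group() == nullptr) return nullptr;  // empty entry
  const auto* group = tensorflow::down_cast<const TpuProgramGroup*>(
      entry.tpu_program_group());
  return group->hlo_metadatas()[entry.core_index()];
}

}  // namespace tpu
}  // namespace tensorflow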