Introduce TpuCompilationCache create function registration

PiperOrigin-RevId: 321394258
Change-Id: I61aa4c39762b078ae4349f4b456a258ac1d13bde
This commit is contained in:
Frank Chen 2020-07-15 10:54:36 -07:00 committed by TensorFlower Gardener
parent 37e9ec4b3b
commit dfb9a633cf
7 changed files with 139 additions and 5 deletions

View File

@ -79,7 +79,10 @@ tf_kernel_library(
srcs = ["tpu_configuration_ops.cc"],
hdrs = ["tpu_configuration_ops.h"],
deps = [
":tpu_compilation_cache_factory",
":tpu_compilation_cache_interface",
":tpu_mesh_state_interface",
":tpu_op_consts",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/xla:util",
@ -133,6 +136,20 @@ tf_proto_library_cc(
],
)
# Factory for TPU compilation caches: provides Get/SetCompilationCacheCreateFn
# so the default TpuCompilationCacheExternal implementation can be swapped out
# at runtime (see tpu_compilation_cache_factory.h).
cc_library(
name = "tpu_compilation_cache_factory",
srcs = ["tpu_compilation_cache_factory.cc"],
hdrs = ["tpu_compilation_cache_factory.h"],
deps = [
":tpu_compilation_cache_external",
":tpu_compilation_cache_interface",
":tpu_op_consts",
"//tensorflow/core:framework",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "tpu_compilation_cache_key",
srcs = [],
@ -323,7 +340,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/platform:casts", # buildcleaner: keep
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/base:core_headers",

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
namespace tensorflow {
namespace tpu {
namespace {

// Builds the default compilation cache implementation. The new object is
// returned with the caller owning the (ref-counted) instance.
TpuCompilationCacheInterface* CreateCompilationCacheExternal() {
  // NOTE: Adjust kCacheSizeBytes to change the compilation cache size.
  // TODO(frankchn): Make this configurable.
  constexpr int64 kCacheSizeBytes = int64{1} << 33;  // 8 GB
  return new TpuCompilationCacheExternal(kCacheSizeBytes);
}

// Heap-allocated so the object with static storage duration is a plain
// pointer, which is trivially destructible (it is intentionally never
// destroyed at process shutdown). `static` is unnecessary inside an
// anonymous namespace, so it is omitted.
std::function<TpuCompilationCacheInterface*()>* compilation_cache_creation_fn =
    new std::function<TpuCompilationCacheInterface*()>(
        CreateCompilationCacheExternal);

}  // namespace
// Returns a copy of the currently registered cache-creation function.
std::function<TpuCompilationCacheInterface*()> GetCompilationCacheCreateFn() {
  const auto& registered_fn = *compilation_cache_creation_fn;
  return registered_fn;
}
// Registers `fn` as the function used to create TPU compilation caches;
// subsequent GetCompilationCacheCreateFn() calls return `fn`.
// NOTE(review): not synchronized — assumes registration happens before any
// concurrent Get/Set callers; confirm against call sites.
void SetCompilationCacheCreateFn(
    std::function<TpuCompilationCacheInterface*()> fn) {
  // Assign in place instead of delete-then-new: the original left
  // `compilation_cache_creation_fn` dangling if the replacement allocation
  // (or the copy of `fn`) threw, and the extra heap round-trip is
  // unnecessary.
  *compilation_cache_creation_fn = std::move(fn);
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
#include <functional>
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
namespace tensorflow {
namespace tpu {
// Returns the currently registered factory used to create TPU compilation
// cache instances (by default an 8 GB TpuCompilationCacheExternal; see
// tpu_compilation_cache_factory.cc).
std::function<TpuCompilationCacheInterface*()> GetCompilationCacheCreateFn();
// Replaces the registered factory; later GetCompilationCacheCreateFn()
// calls return `fn`.
void SetCompilationCacheCreateFn(
std::function<TpuCompilationCacheInterface*()> fn);
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"

View File

@ -23,7 +23,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
@ -67,6 +70,16 @@ Status DeleteIfExists(ResourceMgr* resource_manager,
} // namespace
// Looks up the TPU compilation cache resource in `rmgr`, creating it via the
// registered factory (tpu::GetCompilationCacheCreateFn) if it does not exist
// yet. On success `*compilation_cache` holds a reference the caller must
// Unref when done.
Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
  // Captureless initializer: only invoked when the resource is absent.
  auto make_cache = [](tpu::TpuCompilationCacheInterface** new_cache) {
    *new_cache = tpu::GetCompilationCacheCreateFn()();
    return Status::OK();
  };
  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
      rmgr->default_container(), tpu::kCompilationCacheResourceName,
      compilation_cache, make_cache);
}
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ConfigureDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
@ -98,9 +111,15 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
rmgr, tpu::kTpuMeshStateInterfaceResourceName));
// Create the subgraph compilation cache and put it in the local resource
// manager.
tpu::TpuCompilationCacheInterface* compilation_cache;
OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
core::ScopedUnref compilation_cache_ref(compilation_cache);
tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
num_devices_per_host.size(), num_devices_per_host.data(),
&host_config_output_size, &host_config_output, status);
compilation_cache, &host_config_output_size, &host_config_output, status);
auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
OP_REQUIRES_OK(
@ -230,6 +249,14 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
mesh_state_interface));
}
if (enable_whole_mesh_compilations_) {
// If this is a whole mesh compilation mode, create the compilation cache,
// if missing.
tpu::TpuCompilationCacheInterface* compilation_cache;
OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
compilation_cache->Unref();
}
tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
tpu_host_config.size(), tpu_host_config.data(),
enable_whole_mesh_compilations_, &device_id_output_size,

View File

@ -16,9 +16,13 @@ limitations under the License.
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
namespace tensorflow {
Status CreateTpuCompilationCache(
ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);
// The ConfigureDistributedTpu op is used to start an TPUDriver from
// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
// connection host:port for the CompilationCacheServer. The

View File

@ -35,8 +35,8 @@ extern "C" {
TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
size_t* host_config_output_size, char** host_config_output,
TF_Status* status);
void* tpu_compilation_cache_interface, size_t* host_config_output_size,
char** host_config_output, TF_Status* status);
TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
const size_t num_hosts, const size_t num_cores_per_host,