Introduce TpuCompilationCache create function registration
PiperOrigin-RevId: 321394258
Change-Id: I61aa4c39762b078ae4349f4b456a258ac1d13bde
parent 37e9ec4b3b · commit dfb9a633cf
tensorflow/core/tpu/kernels/BUILD
@@ -79,7 +79,10 @@ tf_kernel_library(
    srcs = ["tpu_configuration_ops.cc"],
    hdrs = ["tpu_configuration_ops.h"],
    deps = [
        ":tpu_compilation_cache_factory",
        ":tpu_compilation_cache_interface",
        ":tpu_mesh_state_interface",
        ":tpu_op_consts",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/compiler/xla:util",
@@ -133,6 +136,20 @@ tf_proto_library_cc(
     ],
 )

+cc_library(
+    name = "tpu_compilation_cache_factory",
+    srcs = ["tpu_compilation_cache_factory.cc"],
+    hdrs = ["tpu_compilation_cache_factory.h"],
+    deps = [
+        ":tpu_compilation_cache_external",
+        ":tpu_compilation_cache_interface",
+        ":tpu_op_consts",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
 cc_library(
     name = "tpu_compilation_cache_key",
     srcs = [],
@@ -323,7 +340,6 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/distributed_runtime/rpc:grpc_call",
        "//tensorflow/core/platform:casts",  # buildcleaner: keep
        "//tensorflow/core/profiler/lib:traceme",
        "@com_google_absl//absl/base:core_headers",
tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.cc (new file, 55 lines)
@@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"

namespace tensorflow {
namespace tpu {
namespace {

TpuCompilationCacheInterface* CreateCompilationCacheExternal() {
  // NOTE: Change the 1 << 33 value to change the compilation cache size.
  // TODO(frankchn): Make this configurable.
  return new TpuCompilationCacheExternal(int64{1} << 33);  // 8 GB
}

// Using a pointer here to fulfill the trivially destructible requirement for
// static variables.
static std::function<TpuCompilationCacheInterface*()>*
    compilation_cache_creation_fn =
        new std::function<TpuCompilationCacheInterface*()>(
            CreateCompilationCacheExternal);

}  // namespace

std::function<TpuCompilationCacheInterface*()> GetCompilationCacheCreateFn() {
  return *compilation_cache_creation_fn;
}

void SetCompilationCacheCreateFn(
    std::function<TpuCompilationCacheInterface*()> fn) {
  delete compilation_cache_creation_fn;
  compilation_cache_creation_fn =
      new std::function<TpuCompilationCacheInterface*()>(fn);
}

}  // namespace tpu
}  // namespace tensorflow
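The registration hook above is what the commit title refers to: the default factory returns an 8 GB TpuCompilationCacheExternal, but a client can swap in a different factory before the cache resource is first created. A minimal sketch of such a registration; UseSmallCompilationCache and the 1 GB size are illustrative, not part of the commit:

#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

namespace tensorflow {
namespace tpu {

// Hypothetical helper: replace the default 8 GB factory with a 1 GB one,
// e.g. for tests. Must run before the first caller invokes
// GetCompilationCacheCreateFn(), since the setter is not synchronized.
void UseSmallCompilationCache() {
  SetCompilationCacheCreateFn([]() -> TpuCompilationCacheInterface* {
    return new TpuCompilationCacheExternal(int64{1} << 30);  // 1 GB
  });
}

}  // namespace tpu
}  // namespace tensorflow

Note the design choice called out in the file's comment: the std::function is heap-allocated, and SetCompilationCacheCreateFn deliberately leaks a fresh one rather than assigning in place, so the static stays trivially destructible.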
tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h (new file, 33 lines)
@@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_

#include <functional>

#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

namespace tensorflow {
namespace tpu {

std::function<TpuCompilationCacheInterface*()> GetCompilationCacheCreateFn();

void SetCompilationCacheCreateFn(
    std::function<TpuCompilationCacheInterface*()> fn);

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
@@ -25,7 +25,6 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
@@ -23,7 +23,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
@@ -67,6 +70,16 @@ Status DeleteIfExists(ResourceMgr* resource_manager,

 }  // namespace

+Status CreateTpuCompilationCache(
+    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
+  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
+      rmgr->default_container(), tpu::kCompilationCacheResourceName,
+      compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
+        *new_cache = tpu::GetCompilationCacheCreateFn()();
+        return Status::OK();
+      });
+}
+
 void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "ConfigureDistributedTpuOp";
   XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
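ResourceMgr::LookupOrCreate hands the cache back with a reference added for the caller, so every successful CreateTpuCompilationCache call must be balanced by an Unref; the two call sites below use core::ScopedUnref and an explicit Unref() respectively. A minimal consumer sketch, assuming a valid ResourceMgr* rmgr in scope inside code that returns Status:

tpu::TpuCompilationCacheInterface* cache = nullptr;
TF_RETURN_IF_ERROR(CreateTpuCompilationCache(rmgr, &cache));
core::ScopedUnref cache_unref(cache);  // drops the caller's ref on scope exit
// ... use cache; the ResourceMgr keeps its own reference alive ...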
@@ -98,9 +111,15 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                           rmgr, tpu::kTpuMeshStateInterfaceResourceName));

+  // Create the subgraph compilation cache and put it in the local resource
+  // manager.
+  tpu::TpuCompilationCacheInterface* compilation_cache;
+  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
+  core::ScopedUnref compilation_cache_ref(compilation_cache);
+
   tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
       num_devices_per_host.size(), num_devices_per_host.data(),
-      &host_config_output_size, &host_config_output, status);
+      compilation_cache, &host_config_output_size, &host_config_output, status);

   auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
   OP_REQUIRES_OK(
@@ -230,6 +249,14 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
                            mesh_state_interface));
   }

+  if (enable_whole_mesh_compilations_) {
+    // If this is a whole mesh compilation mode, create the compilation cache,
+    // if missing.
+    tpu::TpuCompilationCacheInterface* compilation_cache;
+    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
+    compilation_cache->Unref();
+  }
+
   tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
       tpu_host_config.size(), tpu_host_config.data(),
       enable_whole_mesh_compilations_, &device_id_output_size,
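Here the cache is created eagerly and immediately Unref'd: LookupOrCreate leaves the ResourceMgr holding its own reference, so the resource outlives this op while the extra caller reference is dropped right away.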
tensorflow/core/tpu/kernels/tpu_configuration_ops.h
@@ -16,9 +16,13 @@ limitations under the License.
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_

 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

 namespace tensorflow {

+Status CreateTpuCompilationCache(
+    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);
+
 // The ConfigureDistributedTpu op is used to start a TPUDriver from
 // TensorFlow. It should be run on a TPU_SYSTEM device and returns the
 // connection host:port for the CompilationCacheServer. The
tensorflow/core/tpu/tpu_config_c_api.h
@@ -35,8 +35,8 @@ extern "C" {

 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
     const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
-    size_t* host_config_output_size, char** host_config_output,
-    TF_Status* status);
+    void* tpu_compilation_cache_interface, size_t* host_config_output_size,
+    char** host_config_output, TF_Status* status);

 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     const size_t num_hosts, const size_t num_cores_per_host,
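The new void* parameter threads the compilation cache through the C API boundary as an opaque handle; the call site in ConfigureDistributedTpuOp::Compute passes the TpuCompilationCacheInterface* directly. The DoWork implementation lives in the TPU library and is not part of this diff; presumably it recovers the typed pointer along these lines (a hypothetical sketch, not the actual implementation):

void ConfigureDistributedTpuOp_DoWork(
    const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
    void* tpu_compilation_cache_interface, size_t* host_config_output_size,
    char** host_config_output, TF_Status* status) {
  // Hypothetical: cast the opaque handle back to the interface type.
  auto* cache = static_cast<tensorflow::tpu::TpuCompilationCacheInterface*>(
      tpu_compilation_cache_interface);
  // ... library-internal work using `cache` ...
}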