Introduce TpuCompilationCache create function registration
PiperOrigin-RevId: 321394258
Change-Id: I61aa4c39762b078ae4349f4b456a258ac1d13bde
commit dfb9a633cf
parent 37e9ec4b3b
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -79,7 +79,10 @@ tf_kernel_library(
     srcs = ["tpu_configuration_ops.cc"],
     hdrs = ["tpu_configuration_ops.h"],
     deps = [
+        ":tpu_compilation_cache_factory",
+        ":tpu_compilation_cache_interface",
         ":tpu_mesh_state_interface",
+        ":tpu_op_consts",
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/compiler/xla:util",
@@ -133,6 +136,20 @@ tf_proto_library_cc(
     ],
 )
 
+cc_library(
+    name = "tpu_compilation_cache_factory",
+    srcs = ["tpu_compilation_cache_factory.cc"],
+    hdrs = ["tpu_compilation_cache_factory.h"],
+    deps = [
+        ":tpu_compilation_cache_external",
+        ":tpu_compilation_cache_interface",
+        ":tpu_op_consts",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
 cc_library(
     name = "tpu_compilation_cache_key",
     srcs = [],
@@ -323,7 +340,6 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_call",
         "//tensorflow/core/platform:casts",  # buildcleaner: keep
         "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/base:core_headers",
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.cc
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
+
+namespace tensorflow {
+namespace tpu {
+namespace {
+
+TpuCompilationCacheInterface* CreateCompilationCacheExternal() {
+  // NOTE: Change the 1 << 33 value to change the compilation cache size.
+  // TODO(frankchn): Make this configurable.
+  return new TpuCompilationCacheExternal(int64{1} << 33);  // 8 GB
+}
+
+// Using a pointer here to fulfill the trivially destructible requirement for
+// static variables.
+static std::function<TpuCompilationCacheInterface*()>*
+    compilation_cache_creation_fn =
+        new std::function<TpuCompilationCacheInterface*()>(
+            CreateCompilationCacheExternal);
+
+}  // namespace
+
+std::function<TpuCompilationCacheInterface*()> GetCompilationCacheCreateFn() {
+  return *compilation_cache_creation_fn;
+}
+
+void SetCompilationCacheCreateFn(
+    std::function<TpuCompilationCacheInterface*()> fn) {
+  delete compilation_cache_creation_fn;
+  compilation_cache_creation_fn =
+      new std::function<TpuCompilationCacheInterface*()>(fn);
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
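Note: the create function is held through a heap-allocated std::function pointer rather than a plain static object; as the in-code comment says, this keeps the static variable trivially destructible, so nothing runs at static-destruction time and the object is intentionally leaked. One consequence worth flagging (a reading aid, not a statement of the author's intent): SetCompilationCacheCreateFn deletes the old object, so it is only safe when no concurrent GetCompilationCacheCreateFn callers exist, i.e. registration is expected at startup. A minimal standalone sketch of the same idiom, with invented names:

```cpp
// Illustrative sketch of the "leaky registration slot" idiom used above.
// All names here are hypothetical, not part of this commit.
#include <functional>
#include <string>
#include <utility>

using NameFn = std::function<std::string()>;

// A pointer static is trivially destructible; the pointee is never destroyed,
// which sidesteps static-destruction-order problems at process exit.
static NameFn* g_name_fn = new NameFn([] { return std::string("default"); });

void SetNameFn(NameFn fn) {
  delete g_name_fn;  // assumes no concurrent readers (startup-time only)
  g_name_fn = new NameFn(std::move(fn));
}

NameFn GetNameFn() { return *g_name_fn; }
```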
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
+
+#include <functional>
+
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+std::function<TpuCompilationCacheInterface*()> GetCompilationCacheCreateFn();
+
+void SetCompilationCacheCreateFn(
+    std::function<TpuCompilationCacheInterface*()> fn);
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
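Note: together these two files are the registration point named in the commit title. The default create function builds a TpuCompilationCacheExternal, and a different TPU runtime can install its own creator before the configuration ops first look the cache up. A sketch of what a registration might look like; the cache type and its int64 size constructor are taken from this diff, but the 1 GB value and function name are invented here for illustration:

```cpp
// Sketch only: install a creator that builds the stock external cache with a
// smaller size. InstallSmallCompilationCache and the 1 GB choice are
// hypothetical, not part of this commit.
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"

namespace tensorflow {
namespace tpu {

void InstallSmallCompilationCache() {
  SetCompilationCacheCreateFn([]() -> TpuCompilationCacheInterface* {
    return new TpuCompilationCacheExternal(int64{1} << 30);  // 1 GB
  });
}

}  // namespace tpu
}  // namespace tensorflow
```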
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
@@ -25,7 +25,6 @@ limitations under the License.
 #include "absl/synchronization/mutex.h"
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
--- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
@@ -23,7 +23,10 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/refcount.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
 #include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/core/tpu/tpu_config_c_api.h"
 #include "tensorflow/core/tpu/tpu_configuration.h"
@@ -67,6 +70,16 @@ Status DeleteIfExists(ResourceMgr* resource_manager,
 
 }  // namespace
 
+Status CreateTpuCompilationCache(
+    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
+  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
+      rmgr->default_container(), tpu::kCompilationCacheResourceName,
+      compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
+        *new_cache = tpu::GetCompilationCacheCreateFn()();
+        return Status::OK();
+      });
+}
+
 void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "ConfigureDistributedTpuOp";
   XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
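Note: CreateTpuCompilationCache wraps ResourceMgr::LookupOrCreate, which either returns the cache already registered under tpu::kCompilationCacheResourceName in the default container or runs the lambda, deferring to whatever create function the factory currently holds. Either way the caller receives a counted reference it must release, which is why the call sites below pair the call with core::ScopedUnref or an explicit Unref(). A sketch of that caller-side contract; UseCompilationCache is a hypothetical helper, not part of this diff:

```cpp
// Sketch of the caller-side refcount contract (mirrors the op code below).
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"

namespace tensorflow {

Status UseCompilationCache(ResourceMgr* rmgr) {
  tpu::TpuCompilationCacheInterface* cache = nullptr;
  TF_RETURN_IF_ERROR(CreateTpuCompilationCache(rmgr, &cache));
  // LookupOrCreate incremented the refcount on our behalf; drop it on exit.
  core::ScopedUnref cache_ref(cache);
  // ... use `cache` while the reference is held ...
  return Status::OK();
}

}  // namespace tensorflow
```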
@@ -98,9 +111,15 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                           rmgr, tpu::kTpuMeshStateInterfaceResourceName));
 
+  // Create the subgraph compilation cache and put it in the local resource
+  // manager.
+  tpu::TpuCompilationCacheInterface* compilation_cache;
+  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
+  core::ScopedUnref compilation_cache_ref(compilation_cache);
+
   tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
       num_devices_per_host.size(), num_devices_per_host.data(),
-      &host_config_output_size, &host_config_output, status);
+      compilation_cache, &host_config_output_size, &host_config_output, status);
 
   auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
   OP_REQUIRES_OK(
@@ -230,6 +249,14 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
                             mesh_state_interface));
   }
 
+  if (enable_whole_mesh_compilations_) {
+    // If this is a whole mesh compilation mode, create the compilation cache,
+    // if missing.
+    tpu::TpuCompilationCacheInterface* compilation_cache;
+    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
+    compilation_cache->Unref();
+  }
+
   tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
       tpu_host_config.size(), tpu_host_config.data(),
       enable_whole_mesh_compilations_, &device_id_output_size,
--- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
@@ -16,9 +16,13 @@ limitations under the License.
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
 
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 
 namespace tensorflow {
 
+Status CreateTpuCompilationCache(
+    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);
+
 // The ConfigureDistributedTpu op is used to start an TPUDriver from
 // TensorFlow. It should be run on a TPU_SYSTEM device and returns the
 // connection host:port for the CompilationCacheServer. The
--- a/tensorflow/core/tpu/tpu_config_c_api.h
+++ b/tensorflow/core/tpu/tpu_config_c_api.h
@@ -35,8 +35,8 @@ extern "C" {
 
 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
     const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
-    size_t* host_config_output_size, char** host_config_output,
-    TF_Status* status);
+    void* tpu_compilation_cache_interface, size_t* host_config_output_size,
+    char** host_config_output, TF_Status* status);
 
 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     const size_t num_hosts, const size_t num_cores_per_host,
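Note: the cache crosses the C API as an opaque void* because the extern "C" boundary cannot carry the C++ TpuCompilationCacheInterface type; the op passes the pointer straight through, and the library behind the API is left to interpret it. A hypothetical sketch of the receiving side; the cast is an assumption about the implementation, which this diff does not show:

```cpp
// Hypothetical implementation-side sketch, not part of this commit: recover
// the typed pointer from the opaque void* argument.
#include <cstddef>
#include <cstdint>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

void ConfigureDistributedTpuOp_DoWork(
    const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
    void* tpu_compilation_cache_interface, size_t* host_config_output_size,
    char** host_config_output, TF_Status* status) {
  auto* cache = static_cast<tensorflow::tpu::TpuCompilationCacheInterface*>(
      tpu_compilation_cache_interface);
  (void)cache;  // ... configure the TPU system, consulting `cache` as needed.
}
```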