diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 6ff0fb1df73..89a36ed9ae4 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -79,7 +79,10 @@ tf_kernel_library(
     srcs = ["tpu_configuration_ops.cc"],
    hdrs = ["tpu_configuration_ops.h"],
     deps = [
+        ":tpu_compilation_cache_factory",
+        ":tpu_compilation_cache_interface",
         ":tpu_mesh_state_interface",
+        ":tpu_op_consts",
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/compiler/xla:util",
@@ -133,6 +136,20 @@ tf_proto_library_cc(
     ],
 )
 
+cc_library(
+    name = "tpu_compilation_cache_factory",
+    srcs = ["tpu_compilation_cache_factory.cc"],
+    hdrs = ["tpu_compilation_cache_factory.h"],
+    deps = [
+        ":tpu_compilation_cache_external",
+        ":tpu_compilation_cache_interface",
+        ":tpu_op_consts",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
 cc_library(
     name = "tpu_compilation_cache_key",
     srcs = [],
@@ -323,7 +340,6 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_call",
         "//tensorflow/core/platform:casts",  # buildcleaner: keep
         "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/base:core_headers",
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.cc
new file mode 100644
index 00000000000..86469ae7ebb
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.cc
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
+
+namespace tensorflow {
+namespace tpu {
+namespace {
+
+TpuCompilationCacheInterface* CreateCompilationCacheExternal() {
+  // NOTE: Change the 1 << 33 value to change the compilation cache size.
+  // TODO(frankchn): Make this configurable.
+  return new TpuCompilationCacheExternal(int64{1} << 33);  // 8 GB
+}
+
+// Using a pointer here to fulfill the trivially destructible requirement for
+// static variables.
+static std::function<TpuCompilationCacheInterface*()>*
+    compilation_cache_creation_fn =
+        new std::function<TpuCompilationCacheInterface*()>(
+            CreateCompilationCacheExternal);
+
+}  // namespace
+
+std::function<TpuCompilationCacheInterface*()> GetCompilationCacheCreateFn() {
+  return *compilation_cache_creation_fn;
+}
+
+void SetCompilationCacheCreateFn(
+    std::function<TpuCompilationCacheInterface*()> fn) {
+  delete compilation_cache_creation_fn;
+  compilation_cache_creation_fn =
+      new std::function<TpuCompilationCacheInterface*()>(fn);
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h
new file mode 100644
index 00000000000..4710f916c48
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
+
+#include <functional>
+
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+std::function<TpuCompilationCacheInterface*()> GetCompilationCacheCreateFn();
+
+void SetCompilationCacheCreateFn(
+    std::function<TpuCompilationCacheInterface*()> fn);
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_
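The factory above defaults to TpuCompilationCacheExternal, but the Set/Get pair lets another runtime swap in its own cache type before the configuration ops run. A minimal sketch of overriding the hook, assuming the same concrete cache with a smaller budget (InstallSmallCache is a hypothetical helper name, and the 1 GB value is illustrative, not part of this change):

    #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
    #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
    #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

    namespace tensorflow {
    namespace tpu {

    // Install a creation function before ConfigureDistributedTpuOp runs;
    // every later CreateTpuCompilationCache() call picks it up.
    void InstallSmallCache() {
      SetCompilationCacheCreateFn([]() -> TpuCompilationCacheInterface* {
        return new TpuCompilationCacheExternal(int64{1} << 30);  // 1 GB
      });
    }

    }  // namespace tpu
    }  // namespace tensorflow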
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
index 9726d5b78b9..cde6467b7af 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
@@ -25,7 +25,6 @@ limitations under the License.
 #include "absl/synchronization/mutex.h"
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
index 065a7f77dd6..13efdc46e10 100644
--- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
@@ -23,7 +23,10 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/refcount.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
 #include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/core/tpu/tpu_config_c_api.h"
 #include "tensorflow/core/tpu/tpu_configuration.h"
@@ -67,6 +70,16 @@ Status DeleteIfExists(ResourceMgr* resource_manager,
 
 }  // namespace
 
+Status CreateTpuCompilationCache(
+    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
+  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
+      rmgr->default_container(), tpu::kCompilationCacheResourceName,
+      compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
+        *new_cache = tpu::GetCompilationCacheCreateFn()();
+        return Status::OK();
+      });
+}
+
 void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "ConfigureDistributedTpuOp";
   XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
@@ -98,9 +111,15 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                           rmgr, tpu::kTpuMeshStateInterfaceResourceName));
 
+  // Create the subgraph compilation cache and put it in the local resource
+  // manager.
+  tpu::TpuCompilationCacheInterface* compilation_cache;
+  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
+  core::ScopedUnref compilation_cache_ref(compilation_cache);
+
   tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
       num_devices_per_host.size(), num_devices_per_host.data(),
-      &host_config_output_size, &host_config_output, status);
+      compilation_cache, &host_config_output_size, &host_config_output, status);
 
   auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
   OP_REQUIRES_OK(
@@ -230,6 +249,14 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
                             mesh_state_interface));
   }
 
+  if (enable_whole_mesh_compilations_) {
+    // If this is a whole mesh compilation mode, create the compilation cache,
+    // if missing.
+    tpu::TpuCompilationCacheInterface* compilation_cache;
+    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
+    compilation_cache->Unref();
+  }
+
   tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
       tpu_host_config.size(), tpu_host_config.data(),
       enable_whole_mesh_compilations_, &device_id_output_size,
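CreateTpuCompilationCache hands back the extra reference that ResourceMgr::LookupOrCreate gives its caller, which is why ConfigureDistributedTpuOp pins it with core::ScopedUnref for the duration of the call and InitializeHostForDistributedTpuOp drops it immediately with Unref(). A sketch of the same contract for any other caller (UseCompilationCache is a hypothetical name, not part of this change):

    #include "tensorflow/core/framework/resource_mgr.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/core/refcount.h"
    #include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"

    namespace tensorflow {

    Status UseCompilationCache(ResourceMgr* rmgr) {
      tpu::TpuCompilationCacheInterface* cache = nullptr;
      TF_RETURN_IF_ERROR(CreateTpuCompilationCache(rmgr, &cache));
      // LookupOrCreate added one reference for us; release it on scope exit.
      core::ScopedUnref cache_unref(cache);
      // ... use `cache`; the ResourceMgr keeps its own reference, so the
      // resource survives until its container is cleared.
      return Status::OK();
    }

    }  // namespace tensorflow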
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_op_consts.h" #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/core/tpu/tpu_config_c_api.h" #include "tensorflow/core/tpu/tpu_configuration.h" @@ -67,6 +70,16 @@ Status DeleteIfExists(ResourceMgr* resource_manager, } // namespace +Status CreateTpuCompilationCache( + ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) { + return rmgr->LookupOrCreate( + rmgr->default_container(), tpu::kCompilationCacheResourceName, + compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) { + *new_cache = tpu::GetCompilationCacheCreateFn()(); + return Status::OK(); + }); +} + void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) { VLOG(1) << "ConfigureDistributedTpuOp"; XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp"); @@ -98,9 +111,15 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, DeleteIfExists( rmgr, tpu::kTpuMeshStateInterfaceResourceName)); + // Create the subgraph compilation cache and put it in the local resource + // manager. + tpu::TpuCompilationCacheInterface* compilation_cache; + OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache)); + core::ScopedUnref compilation_cache_ref(compilation_cache); + tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn( num_devices_per_host.size(), num_devices_per_host.data(), - &host_config_output_size, &host_config_output, status); + compilation_cache, &host_config_output_size, &host_config_output, status); auto* tpu_mesh = tpu::TpuMeshStateInterface::Create(); OP_REQUIRES_OK( @@ -230,6 +249,14 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { mesh_state_interface)); } + if (enable_whole_mesh_compilations_) { + // If this is a whole mesh compilation mode, create the compilation cache, + // if missing. + tpu::TpuCompilationCacheInterface* compilation_cache; + OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache)); + compilation_cache->Unref(); + } + tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn( tpu_host_config.size(), tpu_host_config.data(), enable_whole_mesh_compilations_, &device_id_output_size, diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h index f75a47e5aaf..d0bf5809842 100644 --- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h @@ -16,9 +16,13 @@ limitations under the License. #define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" namespace tensorflow { +Status CreateTpuCompilationCache( + ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache); + // The ConfigureDistributedTpu op is used to start an TPUDriver from // TensorFlow. It should be run on a TPU_SYSTEM device and returns the // connection host:port for the CompilationCacheServer. 
diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h
index a96cbf38f64..21649050bf7 100644
--- a/tensorflow/core/tpu/tpu_config_c_api.h
+++ b/tensorflow/core/tpu/tpu_config_c_api.h
@@ -35,8 +35,8 @@ extern "C" {
 
 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
     const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
-    size_t* host_config_output_size, char** host_config_output,
-    TF_Status* status);
+    void* tpu_compilation_cache_interface, size_t* host_config_output_size,
+    char** host_config_output, TF_Status* status);
 
 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     const size_t num_hosts, const size_t num_cores_per_host,
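The new parameter is a void* because the extern "C" surface cannot name C++ types, so the cache crosses the ABI boundary as an opaque pointer. The real implementation lives behind the TPU library; the sketch below is only an assumption about how that side might recover the typed pointer (only the signature comes from this change):

    #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
    #include "tensorflow/core/tpu/tpu_config_c_api.h"

    void ConfigureDistributedTpuOp_DoWork(
        const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
        void* tpu_compilation_cache_interface, size_t* host_config_output_size,
        char** host_config_output, TF_Status* status) {
      // Recover the typed pointer on the C++ side of the boundary.
      auto* cache = static_cast<tensorflow::tpu::TpuCompilationCacheInterface*>(
          tpu_compilation_cache_interface);
      (void)cache;  // ... platform-specific configuration would use it here ...
    }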