From 1a16406bcdfb5d0a231dd7e527a63be79f9c83f0 Mon Sep 17 00:00:00 2001 From: Henry Tan Date: Wed, 2 Sep 2020 23:39:39 -0700 Subject: [PATCH] Add Remote Cache support for POD use cases. PiperOrigin-RevId: 329862374 Change-Id: I31cd4841e0b0f08d9c09bc0cec0a5fd1abe6dc13 --- tensorflow/core/BUILD | 6 +- tensorflow/core/platform/build_config.bzl | 2 + .../core/platform/default/build_config.bzl | 5 +- tensorflow/core/tpu/kernels/BUILD | 37 ++++- .../core/tpu/kernels/tpu_configuration_ops.cc | 145 +++++++++++------- .../core/tpu/kernels/tpu_configuration_ops.h | 8 + tensorflow/core/tpu/kernels/tpu_pod_state.cc | 104 ++++++++++++- tensorflow/core/tpu/kernels/tpu_pod_state.h | 9 +- tensorflow/core/tpu/kernels/tpu_util.cc | 10 +- tensorflow/core/tpu/kernels/tpu_util.h | 9 +- tensorflow/core/tpu/tpu_config_c_api.h | 25 ++- tensorflow/core/tpu/tpu_library_init_fns.inc | 6 +- 12 files changed, 281 insertions(+), 85 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b1ebcdbe5b9..8970fce1460 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -72,7 +72,6 @@ load( "if_ios", "if_mobile", "if_not_windows", - "if_tpu", "tf_android_core_proto_headers", "tf_cc_test", "tf_cc_test_mkl", @@ -117,6 +116,7 @@ load( "tf_protos_all_impl", "tf_protos_grappler_impl", "tf_protos_profiler_impl", + "tf_tpu_dependencies", ) load( "//tensorflow/core/platform:rules_cc.bzl", @@ -1086,9 +1086,7 @@ cc_library( ]) + if_tensorrt([ "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels", "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", - ]) + if_tpu([ - "//tensorflow/core/tpu/kernels", - ]), + ]) + tf_tpu_dependencies(), ) cc_library( diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl index 3bfbe617122..cd902ac3353 100644 --- a/tensorflow/core/platform/build_config.bzl +++ b/tensorflow/core/platform/build_config.bzl @@ -43,6 +43,7 @@ load( _tf_py_clif_cc = "tf_py_clif_cc", _tf_pyclif_proto_library = "tf_pyclif_proto_library", _tf_resource_deps = "tf_resource_deps", + _tf_tpu_dependencies = "tf_tpu_dependencies", _tf_windows_aware_platform_deps = "tf_windows_aware_platform_deps", ) @@ -88,3 +89,4 @@ tf_py_clif_cc = _tf_py_clif_cc tf_pyclif_proto_library = _tf_pyclif_proto_library tf_resource_deps = _tf_resource_deps tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps +tf_tpu_dependencies = _tf_tpu_dependencies diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 9f84b9205f1..78191bff8f9 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -1,7 +1,7 @@ # Platform-specific build configurations. load("@com_google_protobuf//:protobuf.bzl", "proto_gen") -load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows") +load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows", "if_tpu") load("//tensorflow/core/platform:build_config_root.bzl", "if_static") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") @@ -800,3 +800,6 @@ def if_llvm_system_z_available(then, otherwise = []): "//tensorflow:linux_s390x": then, "//conditions:default": otherwise, }) + +def tf_tpu_dependencies(): + return if_tpu(["//tensorflow/core/tpu/kernels"]) diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index c47fdc0f9d2..f35f7151222 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -99,13 +99,22 @@ tf_kernel_library( name = "tpu_configuration_ops", srcs = ["tpu_configuration_ops.cc"], hdrs = ["tpu_configuration_ops.h"], - deps = [ + copts = select({ + WITH_TPU_SUPPORT: ["-DLIBTFTPU"], + DEFAULT: [], + }), + deps = select({ + WITH_TPU_SUPPORT: [":tpu_util"], + DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"], + }) + [ ":tpu_compilation_cache_factory", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_local_lookup", ":tpu_compilation_cache_lookup", + ":tpu_compilation_cache_rpc_lookup", ":tpu_mesh_state_interface", ":tpu_op_consts", + ":tpu_pod_state", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", "//tensorflow/compiler/xla:util", @@ -116,6 +125,7 @@ tf_kernel_library( "//tensorflow/core/tpu:tpu_config_c_api", "//tensorflow/core/tpu:tpu_configuration", "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/tpu:proto_helper", ], alwayslink = 1, @@ -447,6 +457,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", + tf_grpc_cc_dependency(), ], ) @@ -505,10 +516,18 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/tpu:tpu_api", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + tf_grpc_cc_dependency(), ], alwayslink = 1, ) +# An alias for +cc_library( + name = "tpu_compilation_cache_cc_proto", + deps = [":tpu_compilation_cache_proto_cc"], +) + cc_library( name = "tpu_compilation_cache_rpc_support_hdrs", hdrs = ["tpu_compilation_cache_rpc_support.h"], @@ -518,7 +537,7 @@ cc_library( }), deps = select({ WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], # build_cleaner: keep + DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], # build_cleaner: keep }) + [ ":tpu_compilation_cache_entry", ":tpu_compilation_cache_interface", @@ -606,7 +625,7 @@ cc_library( }), deps = select({ WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], + DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], }) + [ ":tpu_compilation_cache_common_proto_cc", tf_grpc_cc_dependency(), @@ -628,7 +647,7 @@ cc_library( ], DEFAULT: [ "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc", # build_cleaner: keep + "//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto", # build_cleaner: keep ], }) + [ ":tpu_compilation_cache_common_proto_cc", @@ -939,10 +958,14 @@ cc_library( WITH_TPU_SUPPORT: ["-DLIBTFTPU"], DEFAULT: [], }), - deps = [ + deps = select({ + WITH_TPU_SUPPORT: [":tpu_util"], + DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"], + }) + [ ":tpu_compilation_cache_service", - ":tpu_util", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", + "//tensorflow/core/tpu:tpu_api", "//tensorflow/core:framework", - tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc index 5a8c283c7c2..271a9697f18 100644 --- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc @@ -27,8 +27,10 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" #include "tensorflow/core/tpu/kernels/tpu_op_consts.h" +#include "tensorflow/core/tpu/kernels/tpu_pod_state.h" #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/core/tpu/tpu_config_c_api.h" #include "tensorflow/core/tpu/tpu_configuration.h" @@ -37,7 +39,6 @@ limitations under the License. namespace tensorflow { namespace { - Status GetTpuMeshStateInterface(const ResourceMgr* rmgr, tpu::TpuMeshStateInterface** state) { if (!rmgr->Lookup(rmgr->default_container(), @@ -69,7 +70,6 @@ Status DeleteIfExists(ResourceMgr* resource_manager, VLOG(1) << "Error removing resource " << resource_name << " : " << status; return status; } - } // namespace Status CreateTpuCompilationCache( @@ -82,36 +82,39 @@ Status CreateTpuCompilationCache( }); } -void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "ConfigureDistributedTpuOp"; - XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp"); - +xla::StatusOr> ConstructDevicesPerHost( + OpKernelContext* ctx) { std::vector num_devices_per_host; int chips_per_host = -1; for (int i = 0; i < ctx->num_inputs(); ++i) { const Tensor& input_tensor = ctx->input(i); - OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(input_tensor.shape()), - errors::InvalidArgument("Input ", i, " should be a scalar but has ", - input_tensor.dims(), " dimensions")); + if (!TensorShapeUtils::IsScalar(input_tensor.shape())) { + return errors::InvalidArgument("Input ", i, + " should be a scalar but has ", + input_tensor.dims(), " dimensions"); + } if (chips_per_host == -1) { chips_per_host = input_tensor.scalar()(); } else { - OP_REQUIRES( - ctx, chips_per_host == input_tensor.scalar()(), - errors::Internal("Host ", i, " has ", input_tensor.scalar()(), - " TPU chips but host 0 has ", chips_per_host)); + if (chips_per_host != input_tensor.scalar()()) { + return errors::Internal("Host ", i, " has ", + input_tensor.scalar()(), + " TPU chips but host 0 has ", chips_per_host); + } } num_devices_per_host.push_back(input_tensor.scalar()()); } + return num_devices_per_host; +} - TF_Status* status = TF_NewStatus(); - size_t host_config_output_size; - char* host_config_output; +void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "ConfigureDistributedTpuOp"; + XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp"); - auto* rmgr = GetTPUConfigResourceMgr(); - OP_REQUIRES_OK(ctx, DeleteIfExists( - rmgr, tpu::kTpuMeshStateInterfaceResourceName)); + xla::StatusOr> num_devices_per_host = + ConstructDevicesPerHost(ctx); + OP_REQUIRES_OK(ctx, num_devices_per_host.status()); + ResourceMgr* rmgr = GetTPUConfigResourceMgr(); // Create the subgraph compilation cache and put it in the local resource // manager. @@ -119,9 +122,13 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache)); core::ScopedUnref compilation_cache_ref(compilation_cache); - tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn( - num_devices_per_host.size(), num_devices_per_host.data(), - compilation_cache, &host_config_output_size, &host_config_output, status); + std::string host_config_output; + OP_REQUIRES_OK( + ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache, + &host_config_output)); + + OP_REQUIRES_OK(ctx, DeleteIfExists( + rmgr, tpu::kTpuMeshStateInterfaceResourceName)); auto* tpu_mesh = tpu::TpuMeshStateInterface::Create(); OP_REQUIRES_OK( @@ -130,13 +137,7 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) { Tensor* ctx_output; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output)); - ctx_output->scalar()() = - std::string(host_config_output, host_config_output_size); - - OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); - TF_DeleteStatus(status); - - tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output); + ctx_output->scalar()() = std::move(host_config_output); VLOG(1) << "ConfigureDistributedTpuOp done"; } @@ -186,30 +187,39 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) { mapping_arg.push_back(mapping[i].data()); } - TF_Status* status = TF_NewStatus(); - size_t tpu_topology_output_size; - char* tpu_topology_output; - tpu::TpuMeshStateInterface* mesh_state; auto* rmgr = GetTPUConfigResourceMgr(); OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state)); core::ScopedUnref mesh_state_unref(mesh_state); + // TODO(b/166858751): this code to check if `TpuPodState` exists is ported + // from a legacy library that may have staled. A candidate for cleanup. + TpuPodState* pod_state; + OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state)); + core::ScopedUnref pod_state_unref(pod_state); + + size_t tpu_topology_output_size; + char* tpu_topology_output = nullptr; + TF_Status* status = TF_NewStatus(); + auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() { + TF_DeleteStatus(status); + tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn( + tpu_topology_output); + }); + auto* mesh_common_state = mesh_state->mesh_common_state(); tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn( num_hosts, num_devices_per_host, const_cast(mapping_arg.data()), mesh_common_state, &tpu_topology_output_size, &tpu_topology_output, status); + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); + Tensor* ctx_output; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output)); ctx_output->scalar()() = std::string(tpu_topology_output, tpu_topology_output_size); - OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); - TF_DeleteStatus(status); - tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output); - VLOG(1) << "WaitForDistributedTpuOp done"; } @@ -217,17 +227,14 @@ void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) { VLOG(1) << "ShutdownDistributedTpuOp"; XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp"); - TF_Status* status = TF_NewStatus(); + auto* rmgr = GetTPUConfigResourceMgr(); OP_REQUIRES_OK(ctx, DeleteIfExists( - GetTPUConfigResourceMgr(), - tpu::kTpuMeshStateInterfaceResourceName)); - tpu::ConfigApiFn()->ShutdownDistributedTpuOp_DoWorkFn(status); - OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); - TF_DeleteStatus(status); + rmgr, tpu::kTpuMeshStateInterfaceResourceName)); - OP_REQUIRES_OK( - ctx, DeleteIfExists( - GetTPUConfigResourceMgr(), tpu::kCompilationCacheResourceName)); + OP_REQUIRES_OK(ctx, + DeleteIfExists(rmgr, kTpuPodStateResourceName)); + OP_REQUIRES_OK(ctx, DeleteIfExists( + rmgr, tpu::kCompilationCacheResourceName)); VLOG(1) << "ShutdownDistributedTpuOp done"; } @@ -239,10 +246,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { auto* rmgr = GetTPUConfigResourceMgr(); auto tpu_host_config = ctx->input(0).scalar()(); - size_t device_id_output_size; - int32_t* device_id_output; - TF_Status* status = TF_NewStatus(); - bool is_master_worker = tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn(); if (!is_master_worker) { @@ -275,10 +278,18 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { local_compilation_cache = nullptr; } + TF_Status* status = TF_NewStatus(); + size_t device_id_output_size; + int32_t* device_id_output = nullptr; + auto cleanup = xla::MakeCleanup([&status, &device_id_output]() { + TF_DeleteStatus(status); + tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output); + }); tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn( tpu_host_config.size(), tpu_host_config.data(), - enable_whole_mesh_compilations_, local_compilation_cache, - &device_id_output_size, &device_id_output, status); + enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size, + &device_id_output, status); + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); if (local_compilation_cache != nullptr) { local_compilation_cache->Unref(); @@ -289,6 +300,30 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK( ctx, rmgr->Create(rmgr->default_container(), tpu::kCompiledProtoCacheResourceName, proto_lookup)); + } else { + int64_t cache_size_bytes; + tpu::ConfigApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn( + &cache_size_bytes); + + char* server_address_output = nullptr; + auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() { + tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn( + server_address_output); + }); + size_t server_address_output_size; + tpu::ConfigApiFn() + ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn( + tpu_host_config.size(), tpu_host_config.data(), + &server_address_output_size, &server_address_output, status); + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); + + std::string server_address(server_address_output, + server_address_output_size); + tpu::TpuCompilationCacheLookup* proto_lookup = + new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes); + OP_REQUIRES_OK( + ctx, rmgr->Create(rmgr->default_container(), + tpu::kCompiledProtoCacheResourceName, proto_lookup)); } Tensor* ctx_output; @@ -301,10 +336,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { ctx_output->flat()(i) = device_id_output[i]; } - OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); - TF_DeleteStatus(status); - tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output); - VLOG(1) << "InitializeHostForDistributedTpuOp done"; } diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h index d0bf5809842..d58712ae3dd 100644 --- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h @@ -15,14 +15,22 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ #define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ +#include + +#include + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { Status CreateTpuCompilationCache( ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache); +xla::StatusOr> ConstructDevicesPerHost( + OpKernelContext* ctx); + // The ConfigureDistributedTpu op is used to start an TPUDriver from // TensorFlow. It should be run on a TPU_SYSTEM device and returns the // connection host:port for the CompilationCacheServer. The diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.cc b/tensorflow/core/tpu/kernels/tpu_pod_state.cc index a45a4d63708..e7f13a657ed 100644 --- a/tensorflow/core/tpu/kernels/tpu_pod_state.cc +++ b/tensorflow/core/tpu/kernels/tpu_pod_state.cc @@ -14,12 +14,78 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tpu/kernels/tpu_pod_state.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/tpu/tpu_api.h" + +#if defined(LIBTFTPU) #include "tensorflow/core/tpu/kernels/tpu_util.h" +#else +#include "tensorflow/core/tpu/kernels/tpu_util.h" // copybara" +#endif namespace tensorflow { - const char kTpuPodStateResourceName[] = "tpu_pod_state"; +namespace { +Status GetServerAddressAndPort(std::string* server_address, int* serving_port) { + TF_Status* status = TF_NewStatus(); + char* server_address_output = nullptr; + auto cleanup = xla::MakeCleanup([&status, &server_address_output]() { + TF_DeleteStatus(status); + tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn( + server_address_output); + }); + size_t server_address_output_size; + *serving_port = -1; + tpu::ConfigApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn( + &server_address_output_size, &server_address_output, serving_port, + status); + CHECK_NE(*serving_port, -1); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status)); + return Status::OK(); +} + +// Attempt to delete resource_name from resource_manager's default_container. +// Returns OK if the deletion succeeded, or if the resource was not found. Else +// return the deletion error. +template +Status DeleteIfExists(ResourceMgr* resource_manager, + const char* resource_name) { + VLOG(1) << "Removing resource " << resource_name << " if it exists"; + Status status = resource_manager->Delete( + resource_manager->default_container(), resource_name); + if (status.ok()) { + VLOG(1) << "Removed existing resource " << resource_name; + return Status::OK(); + } + if (status.code() == error::NOT_FOUND) { + VLOG(1) << "No resource " << resource_name << " to remove"; + return Status::OK(); + } + VLOG(1) << "Error removing resource " << resource_name << " : " << status; + return status; +} + +xla::StatusOr> +ConstructCacheService(ResourceMgr* rmgr, int serving_port, + tpu::TpuCompilationCacheInterface* compilation_cache) { + xla::StatusOr> server_builder; +#if defined(LIBTFTPU) + server_builder = tpu::CreateServerBuilder(serving_port); +#else + server_builder = tpu::CreateServerBuilderGoogle(serving_port); +#endif + TF_RETURN_IF_ERROR(server_builder.status()); + + auto cache_service = absl::make_unique( + server_builder.ValueOrDie().get(), compilation_cache); + cache_service->SetMemoryQuota(1ul << 31); // 2GB + cache_service->Start(); + return cache_service; +} +} // namespace + TpuPodState::TpuPodState( int service_port, std::unique_ptr cache_service) : cache_service_(std::move(cache_service)), service_port_(service_port) {} @@ -29,7 +95,7 @@ TpuPodState::~TpuPodState() { VLOG(1) << "Shutting down Compilation Cache Service."; if (cache_service_->Shutdown(20)) { if (service_port_ >= 0) { - tpu::RecycleUnusedPort(service_port_); + tpu::UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_); } } else { LOG(ERROR) @@ -67,4 +133,38 @@ bool HasTPUPodState(const ResourceMgr* rmgr) { return true; } +Status ConstructTpuPodState( + ResourceMgr* rmgr, const std::vector& num_devices_per_host, + tpu::TpuCompilationCacheInterface* compilation_cache, + std::string* host_config_proto) { + TF_Status* status = TF_NewStatus(); + auto status_cleanup = + xla::MakeCleanup([&status]() { TF_DeleteStatus(status); }); + + int serving_port; + std::string server_address; + TF_RETURN_IF_ERROR(GetServerAddressAndPort(&server_address, &serving_port)); + + char* host_config_output = nullptr; + auto host_config_cleanup = xla::MakeCleanup([&host_config_output]() { + tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output); + }); + size_t host_config_output_size; + tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn( + num_devices_per_host.size(), num_devices_per_host.data(), + server_address.size(), server_address.data(), &host_config_output_size, + &host_config_output, status); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status)); + *host_config_proto = std::string(host_config_output, host_config_output_size); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr cache_service, + ConstructCacheService(rmgr, serving_port, compilation_cache)); + + // Delete TpuPodState if it exists, and recreate below. + TF_RETURN_IF_ERROR( + DeleteIfExists(rmgr, kTpuPodStateResourceName)); + return rmgr->Create(rmgr->default_container(), kTpuPodStateResourceName, + new TpuPodState(serving_port, std::move(cache_service))); +} } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.h b/tensorflow/core/tpu/kernels/tpu_pod_state.h index 9f37e28f60f..07ad3bee553 100644 --- a/tensorflow/core/tpu/kernels/tpu_pod_state.h +++ b/tensorflow/core/tpu/kernels/tpu_pod_state.h @@ -15,7 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ #define TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ -#include "grpcpp/server_builder.h" +#include +#include + #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h" @@ -49,6 +51,11 @@ Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state); // manager. bool HasTPUPodState(const ResourceMgr* rmgr); +// Construct TpuPodState. +Status ConstructTpuPodState( + ResourceMgr* rmgr, const std::vector& num_devices_per_host, + tpu::TpuCompilationCacheInterface* compilation_cache, + std::string* host_config_proto); } // namespace tensorflow #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_util.cc b/tensorflow/core/tpu/kernels/tpu_util.cc index 837c23c6cf5..6f31d066db5 100644 --- a/tensorflow/core/tpu/kernels/tpu_util.cc +++ b/tensorflow/core/tpu/kernels/tpu_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tpu/kernels/tpu_util.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/tpu/tpu_api.h" @@ -97,8 +98,13 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes, return Status::OK(); } -void RecycleUnusedPort(int port) { - UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(port); +xla::StatusOr> CreateServerBuilder( + int serving_port) { + auto server_builder = absl::make_unique<::grpc::ServerBuilder>(); + server_builder->AddListeningPort( + absl::StrFormat("[::]:%d", serving_port), + ::grpc::InsecureServerCredentials()); // NOLINT + return std::move(server_builder); } } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_util.h b/tensorflow/core/tpu/kernels/tpu_util.h index 834db31c3d8..d45934f31b6 100644 --- a/tensorflow/core/tpu/kernels/tpu_util.h +++ b/tensorflow/core/tpu/kernels/tpu_util.h @@ -15,9 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_ #define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_ +#include #include #include +#include "grpcpp/server_builder.h" #include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -55,10 +57,9 @@ Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes, Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes, std::vector* shapes); -// We only recycle ports which were given to us by the portserver. For ports -// we obtained through local trial-and-error, there is no reason to expect the -// port to remain available after it is unbound. -void RecycleUnusedPort(int port); +// Creates gRPC ServerBuilder. +xla::StatusOr> CreateServerBuilder( + int serving_port); } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h index 08417dbf907..de4b2e25570 100644 --- a/tensorflow/core/tpu/tpu_config_c_api.h +++ b/tensorflow/core/tpu/tpu_config_c_api.h @@ -32,8 +32,9 @@ extern "C" { TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork( const size_t num_cores_per_host_size, const int32_t* num_cores_per_host, - void* tpu_compilation_cache_interface, size_t* host_config_output_size, - char** host_config_output, TF_Status* status); + size_t server_address_size, const char* server_address, + size_t* host_config_output_size, char** host_config_output, + TF_Status* status); TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork( const size_t num_hosts, const size_t num_cores_per_host, @@ -42,11 +43,9 @@ TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork( size_t* tpu_topology_output_size, char** tpu_topology_output, TF_Status* status); -TFTPU_CAPI_EXPORT void ShutdownDistributedTpuOp_DoWork(TF_Status* status); - TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork( const size_t tpu_host_config_size, const char* tpu_host_config, - const bool enable_whole_mesh_compilations, void* local_compilation_cache, + const bool enable_whole_mesh_compilations, bool is_master_worker, size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status); TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork( @@ -65,12 +64,22 @@ TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus, TF_Status* status); TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit, TF_Status* status); + +TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes( + int64_t* cache_size_in_bytes); +TFTPU_CAPI_EXPORT +void TpuConfigurationApi_CompilationCacheServerAddressFromConfig( + size_t tpu_host_config_size, const char* tpu_host_config, + size_t* server_address_output_size, char** server_address_output, + TF_Status* status); +TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort( + size_t* server_address_output_size, char** server_address_output, + int* port_output, TF_Status* status); } struct TfTpu_ConfigApiFn { TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork); TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork); - TFTPU_ADD_FN_IN_STRUCT(ShutdownDistributedTpuOp_DoWork); TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork); TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork); TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork); @@ -79,6 +88,10 @@ struct TfTpu_ConfigApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState); TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost); TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); + TFTPU_ADD_FN_IN_STRUCT( + TpuConfigurationApi_CompilationCacheServerAddressFromConfig); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort); }; #endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_ diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc index cb8871a60c5..fde2712a2f0 100644 --- a/tensorflow/core/tpu/tpu_library_init_fns.inc +++ b/tensorflow/core/tpu/tpu_library_init_fns.inc @@ -11,7 +11,6 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) { TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork); TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork); - TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork); TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork); TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork); TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork); @@ -20,6 +19,11 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) { TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState); TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpusPerHost); TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpuMemoryLimit); + TFTPU_SET_FN(config_fn, + TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); + TFTPU_SET_FN(config_fn, + TpuConfigurationApi_CompilationCacheServerAddressFromConfig); + TFTPU_SET_FN(config_fn, TpuConfigurationApi_GetServerAddressAndPort); return tensorflow::Status::OK(); }