Add remote compilation cache support for TPU Pod use cases.

PiperOrigin-RevId: 329862374
Change-Id: I31cd4841e0b0f08d9c09bc0cec0a5fd1abe6dc13
Henry Tan 2020-09-02 23:39:39 -07:00 committed by TensorFlower Gardener
parent 5104953f4c
commit 1a16406bcd
12 changed files with 281 additions and 85 deletions

View File

@ -72,7 +72,6 @@ load(
"if_ios",
"if_mobile",
"if_not_windows",
"if_tpu",
"tf_android_core_proto_headers",
"tf_cc_test",
"tf_cc_test_mkl",
@ -117,6 +116,7 @@ load(
"tf_protos_all_impl",
"tf_protos_grappler_impl",
"tf_protos_profiler_impl",
"tf_tpu_dependencies",
)
load(
"//tensorflow/core/platform:rules_cc.bzl",
@ -1086,9 +1086,7 @@ cc_library(
]) + if_tensorrt([
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
"//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
]) + if_tpu([
"//tensorflow/core/tpu/kernels",
]),
]) + tf_tpu_dependencies(),
)
cc_library(

View File

@ -43,6 +43,7 @@ load(
_tf_py_clif_cc = "tf_py_clif_cc",
_tf_pyclif_proto_library = "tf_pyclif_proto_library",
_tf_resource_deps = "tf_resource_deps",
_tf_tpu_dependencies = "tf_tpu_dependencies",
_tf_windows_aware_platform_deps = "tf_windows_aware_platform_deps",
)
@ -88,3 +89,4 @@ tf_py_clif_cc = _tf_py_clif_cc
tf_pyclif_proto_library = _tf_pyclif_proto_library
tf_resource_deps = _tf_resource_deps
tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps
tf_tpu_dependencies = _tf_tpu_dependencies

View File

@ -1,7 +1,7 @@
# Platform-specific build configurations.
load("@com_google_protobuf//:protobuf.bzl", "proto_gen")
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows")
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows", "if_tpu")
load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
@ -800,3 +800,6 @@ def if_llvm_system_z_available(then, otherwise = []):
"//tensorflow:linux_s390x": then,
"//conditions:default": otherwise,
})
def tf_tpu_dependencies():
return if_tpu(["//tensorflow/core/tpu/kernels"])

View File

@ -99,13 +99,22 @@ tf_kernel_library(
name = "tpu_configuration_ops",
srcs = ["tpu_configuration_ops.cc"],
hdrs = ["tpu_configuration_ops.h"],
deps = [
copts = select({
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
DEFAULT: [],
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_util"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
}) + [
":tpu_compilation_cache_factory",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_local_lookup",
":tpu_compilation_cache_lookup",
":tpu_compilation_cache_rpc_lookup",
":tpu_mesh_state_interface",
":tpu_op_consts",
":tpu_pod_state",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/xla:util",
@ -116,6 +125,7 @@ tf_kernel_library(
"//tensorflow/core/tpu:tpu_config_c_api",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/tpu:proto_helper",
],
alwayslink = 1,
@ -447,6 +457,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
tf_grpc_cc_dependency(),
],
)
@ -505,10 +516,18 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/tpu:tpu_api",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
tf_grpc_cc_dependency(),
],
alwayslink = 1,
)
# An alias for ":tpu_compilation_cache_proto_cc".
cc_library(
name = "tpu_compilation_cache_cc_proto",
deps = [":tpu_compilation_cache_proto_cc"],
)
cc_library(
name = "tpu_compilation_cache_rpc_support_hdrs",
hdrs = ["tpu_compilation_cache_rpc_support.h"],
@ -518,7 +537,7 @@ cc_library(
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], # build_cleaner: keep
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], # build_cleaner: keep
}) + [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_interface",
@ -606,7 +625,7 @@ cc_library(
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
}) + [
":tpu_compilation_cache_common_proto_cc",
tf_grpc_cc_dependency(),
@ -628,7 +647,7 @@ cc_library(
],
DEFAULT: [
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc", # build_cleaner: keep
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto", # build_cleaner: keep
],
}) + [
":tpu_compilation_cache_common_proto_cc",
@ -939,10 +958,14 @@ cc_library(
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
DEFAULT: [],
}),
deps = [
deps = select({
WITH_TPU_SUPPORT: [":tpu_util"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
}) + [
":tpu_compilation_cache_service",
":tpu_util",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core:framework",
tf_grpc_cc_dependency(),
],
)

View File

@ -27,8 +27,10 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
@ -37,7 +39,6 @@ limitations under the License.
namespace tensorflow {
namespace {
Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
tpu::TpuMeshStateInterface** state) {
if (!rmgr->Lookup(rmgr->default_container(),
@ -69,7 +70,6 @@ Status DeleteIfExists(ResourceMgr* resource_manager,
VLOG(1) << "Error removing resource " << resource_name << " : " << status;
return status;
}
} // namespace
Status CreateTpuCompilationCache(
@ -82,36 +82,39 @@ Status CreateTpuCompilationCache(
});
}
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ConfigureDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
OpKernelContext* ctx) {
std::vector<int32_t> num_devices_per_host;
int chips_per_host = -1;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& input_tensor = ctx->input(i);
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(input_tensor.shape()),
errors::InvalidArgument("Input ", i, " should be a scalar but has ",
input_tensor.dims(), " dimensions"));
if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
return errors::InvalidArgument("Input ", i,
" should be a scalar but has ",
input_tensor.dims(), " dimensions");
}
if (chips_per_host == -1) {
chips_per_host = input_tensor.scalar<int32_t>()();
} else {
OP_REQUIRES(
ctx, chips_per_host == input_tensor.scalar<int32>()(),
errors::Internal("Host ", i, " has ", input_tensor.scalar<int32>()(),
" TPU chips but host 0 has ", chips_per_host));
if (chips_per_host != input_tensor.scalar<int32>()()) {
return errors::Internal("Host ", i, " has ",
input_tensor.scalar<int32>()(),
" TPU chips but host 0 has ", chips_per_host);
}
}
num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
}
return num_devices_per_host;
}
TF_Status* status = TF_NewStatus();
size_t host_config_output_size;
char* host_config_output;
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ConfigureDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
auto* rmgr = GetTPUConfigResourceMgr();
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
rmgr, tpu::kTpuMeshStateInterfaceResourceName));
xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
ConstructDevicesPerHost(ctx);
OP_REQUIRES_OK(ctx, num_devices_per_host.status());
ResourceMgr* rmgr = GetTPUConfigResourceMgr();
// Create the subgraph compilation cache and put it in the local resource
// manager.
@ -119,9 +122,13 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
core::ScopedUnref compilation_cache_ref(compilation_cache);
tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
num_devices_per_host.size(), num_devices_per_host.data(),
compilation_cache, &host_config_output_size, &host_config_output, status);
std::string host_config_output;
OP_REQUIRES_OK(
ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
&host_config_output));
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
rmgr, tpu::kTpuMeshStateInterfaceResourceName));
auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
OP_REQUIRES_OK(
@ -130,13 +137,7 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
ctx_output->scalar<tstring>()() =
std::string(host_config_output, host_config_output_size);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
ctx_output->scalar<tstring>()() = std::move(host_config_output);
VLOG(1) << "ConfigureDistributedTpuOp done";
}
@ -186,30 +187,39 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
mapping_arg.push_back(mapping[i].data());
}
TF_Status* status = TF_NewStatus();
size_t tpu_topology_output_size;
char* tpu_topology_output;
tpu::TpuMeshStateInterface* mesh_state;
auto* rmgr = GetTPUConfigResourceMgr();
OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
core::ScopedUnref mesh_state_unref(mesh_state);
// TODO(b/166858751): this check for an existing `TpuPodState` is ported from
// a legacy library that may have gone stale; it is a candidate for cleanup.
TpuPodState* pod_state;
OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
core::ScopedUnref pod_state_unref(pod_state);
size_t tpu_topology_output_size;
char* tpu_topology_output = nullptr;
TF_Status* status = TF_NewStatus();
auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
tpu_topology_output);
});
auto* mesh_common_state = mesh_state->mesh_common_state();
tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
num_hosts, num_devices_per_host,
const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
&tpu_topology_output_size, &tpu_topology_output, status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
ctx_output->scalar<tstring>()() =
std::string(tpu_topology_output, tpu_topology_output_size);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
VLOG(1) << "WaitForDistributedTpuOp done";
}
@ -217,17 +227,14 @@ void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ShutdownDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");
TF_Status* status = TF_NewStatus();
auto* rmgr = GetTPUConfigResourceMgr();
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
GetTPUConfigResourceMgr(),
tpu::kTpuMeshStateInterfaceResourceName));
tpu::ConfigApiFn()->ShutdownDistributedTpuOp_DoWorkFn(status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
rmgr, tpu::kTpuMeshStateInterfaceResourceName));
OP_REQUIRES_OK(
ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
GetTPUConfigResourceMgr(), tpu::kCompilationCacheResourceName));
OP_REQUIRES_OK(ctx,
DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
rmgr, tpu::kCompilationCacheResourceName));
VLOG(1) << "ShutdownDistributedTpuOp done";
}
@ -239,10 +246,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
auto* rmgr = GetTPUConfigResourceMgr();
auto tpu_host_config = ctx->input(0).scalar<tstring>()();
size_t device_id_output_size;
int32_t* device_id_output;
TF_Status* status = TF_NewStatus();
bool is_master_worker =
tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
if (!is_master_worker) {
@ -275,10 +278,18 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
local_compilation_cache = nullptr;
}
TF_Status* status = TF_NewStatus();
size_t device_id_output_size;
int32_t* device_id_output = nullptr;
auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
});
tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
tpu_host_config.size(), tpu_host_config.data(),
enable_whole_mesh_compilations_, local_compilation_cache,
&device_id_output_size, &device_id_output, status);
enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size,
&device_id_output, status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
if (local_compilation_cache != nullptr) {
local_compilation_cache->Unref();
@ -289,6 +300,30 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(
ctx, rmgr->Create(rmgr->default_container(),
tpu::kCompiledProtoCacheResourceName, proto_lookup));
} else {
int64_t cache_size_bytes;
tpu::ConfigApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
&cache_size_bytes);
char* server_address_output = nullptr;
auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
server_address_output);
});
size_t server_address_output_size;
tpu::ConfigApiFn()
->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
tpu_host_config.size(), tpu_host_config.data(),
&server_address_output_size, &server_address_output, status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
std::string server_address(server_address_output,
server_address_output_size);
tpu::TpuCompilationCacheLookup* proto_lookup =
new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes);
OP_REQUIRES_OK(
ctx, rmgr->Create(rmgr->default_container(),
tpu::kCompiledProtoCacheResourceName, proto_lookup));
}
Tensor* ctx_output;
@ -301,10 +336,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
ctx_output->flat<int32>()(i) = device_id_output[i];
}
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
VLOG(1) << "InitializeHostForDistributedTpuOp done";
}
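
The else branch added to InitializeHostForDistributedTpuOp above is the heart of the remote-cache support: a worker without a local compilation cache now resolves compiled programs through a gRPC lookup against the master's compilation cache server. Below is a condensed sketch of that selection logic; CreateProtoCacheLookup is a hypothetical wrapper with includes and error handling omitted, while the constructor arguments and the resource name are taken from this diff.

// Hypothetical wrapper condensing the lookup selection performed by
// InitializeHostForDistributedTpuOp (assumes the includes used by that file).
Status CreateProtoCacheLookup(
    ResourceMgr* rmgr,
    tpu::TpuCompilationCacheInterface* local_compilation_cache,
    const std::string& server_address, int64_t cache_size_bytes) {
  tpu::TpuCompilationCacheLookup* proto_lookup = nullptr;
  if (local_compilation_cache != nullptr) {
    // Master worker: serve lookups directly from the local compilation cache.
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
  } else {
    // Other workers: fetch compiled programs over gRPC from the master's
    // compilation cache server.
    proto_lookup =
        new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes);
  }
  return rmgr->Create(rmgr->default_container(),
                      tpu::kCompiledProtoCacheResourceName, proto_lookup);
}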

View File

@ -15,14 +15,22 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#include <stdint.h>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
Status CreateTpuCompilationCache(
ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);
xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
OpKernelContext* ctx);
// The ConfigureDistributedTpu op is used to start a TPUDriver from
// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
// connection host:port for the CompilationCacheServer. The

View File

@ -14,12 +14,78 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/tpu/tpu_api.h"
#if defined(LIBTFTPU)
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#else
#include "tensorflow/core/tpu/kernels/tpu_util.h" // copybara"
#endif
namespace tensorflow {
const char kTpuPodStateResourceName[] = "tpu_pod_state";
namespace {
Status GetServerAddressAndPort(std::string* server_address, int* serving_port) {
TF_Status* status = TF_NewStatus();
char* server_address_output = nullptr;
auto cleanup = xla::MakeCleanup([&status, &server_address_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
server_address_output);
});
size_t server_address_output_size;
*serving_port = -1;
tpu::ConfigApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(
&server_address_output_size, &server_address_output, serving_port,
status);
TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
*server_address = std::string(server_address_output, server_address_output_size);
CHECK_NE(*serving_port, -1);
return Status::OK();
}
// Attempt to delete resource_name from resource_manager's default_container.
// Returns OK if the deletion succeeded or if the resource was not found;
// otherwise returns the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
const char* resource_name) {
VLOG(1) << "Removing resource " << resource_name << " if it exists";
Status status = resource_manager->Delete<ResourceT>(
resource_manager->default_container(), resource_name);
if (status.ok()) {
VLOG(1) << "Removed existing resource " << resource_name;
return Status::OK();
}
if (status.code() == error::NOT_FOUND) {
VLOG(1) << "No resource " << resource_name << " to remove";
return Status::OK();
}
VLOG(1) << "Error removing resource " << resource_name << " : " << status;
return status;
}
xla::StatusOr<std::unique_ptr<TpuCompilationCacheService>>
ConstructCacheService(ResourceMgr* rmgr, int serving_port,
tpu::TpuCompilationCacheInterface* compilation_cache) {
xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> server_builder;
#if defined(LIBTFTPU)
server_builder = tpu::CreateServerBuilder(serving_port);
#else
server_builder = tpu::CreateServerBuilderGoogle(serving_port);
#endif
TF_RETURN_IF_ERROR(server_builder.status());
auto cache_service = absl::make_unique<TpuCompilationCacheService>(
server_builder.ValueOrDie().get(), compilation_cache);
cache_service->SetMemoryQuota(1ul << 31); // 2GB
cache_service->Start();
return cache_service;
}
} // namespace
TpuPodState::TpuPodState(
int service_port, std::unique_ptr<TpuCompilationCacheService> cache_service)
: cache_service_(std::move(cache_service)), service_port_(service_port) {}
@ -29,7 +95,7 @@ TpuPodState::~TpuPodState() {
VLOG(1) << "Shutting down Compilation Cache Service.";
if (cache_service_->Shutdown(20)) {
if (service_port_ >= 0) {
tpu::RecycleUnusedPort(service_port_);
tpu::UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
}
} else {
LOG(ERROR)
@ -67,4 +133,38 @@ bool HasTPUPodState(const ResourceMgr* rmgr) {
return true;
}
Status ConstructTpuPodState(
ResourceMgr* rmgr, const std::vector<int32_t>& num_devices_per_host,
tpu::TpuCompilationCacheInterface* compilation_cache,
std::string* host_config_proto) {
TF_Status* status = TF_NewStatus();
auto status_cleanup =
xla::MakeCleanup([&status]() { TF_DeleteStatus(status); });
int serving_port;
std::string server_address;
TF_RETURN_IF_ERROR(GetServerAddressAndPort(&server_address, &serving_port));
char* host_config_output = nullptr;
auto host_config_cleanup = xla::MakeCleanup([&host_config_output]() {
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
});
size_t host_config_output_size;
tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
num_devices_per_host.size(), num_devices_per_host.data(),
server_address.size(), server_address.data(), &host_config_output_size,
&host_config_output, status);
TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
*host_config_proto = std::string(host_config_output, host_config_output_size);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<TpuCompilationCacheService> cache_service,
ConstructCacheService(rmgr, serving_port, compilation_cache));
// Delete TpuPodState if it exists, and recreate below.
TF_RETURN_IF_ERROR(
DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
return rmgr->Create(rmgr->default_container(), kTpuPodStateResourceName,
new TpuPodState(serving_port, std::move(cache_service)));
}
} // namespace tensorflow
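
ConstructTpuPodState registers the TpuPodState under kTpuPodStateResourceName in the TPU config resource manager; other ops then look it up by name, as WaitForDistributedTpuOp does earlier in this change. A minimal retrieval sketch using the same names (the TF_CHECK_OK call and surrounding context are illustrative only, not part of the change):

// Sketch: retrieving the pod state registered by ConstructTpuPodState.
// The lookup adds a reference, so the caller unrefs when done.
ResourceMgr* rmgr = GetTPUConfigResourceMgr();
if (HasTPUPodState(rmgr)) {
  TpuPodState* pod_state = nullptr;
  TF_CHECK_OK(GetTPUPodState(rmgr, &pod_state));
  core::ScopedUnref pod_state_unref(pod_state);
  // ... use pod_state while the unref guard keeps it alive ...
}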

View File

@ -15,7 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
#include "grpcpp/server_builder.h"
#include <string>
#include <vector>
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h"
@ -49,6 +51,11 @@ Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state);
// manager.
bool HasTPUPodState(const ResourceMgr* rmgr);
// Construct TpuPodState.
Status ConstructTpuPodState(
ResourceMgr* rmgr, const std::vector<int32_t>& num_devices_per_host,
tpu::TpuCompilationCacheInterface* compilation_cache,
std::string* host_config_proto);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/tpu/tpu_api.h"
@ -97,8 +98,13 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
return Status::OK();
}
void RecycleUnusedPort(int port) {
UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(port);
xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> CreateServerBuilder(
int serving_port) {
auto server_builder = absl::make_unique<::grpc::ServerBuilder>();
server_builder->AddListeningPort(
absl::StrFormat("[::]:%d", serving_port),
::grpc::InsecureServerCredentials()); // NOLINT
return std::move(server_builder);
}
} // namespace tpu
} // namespace tensorflow
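
CreateServerBuilder only configures the listening port; the caller attaches services and starts the server (in this change, ConstructCacheService hands the builder to TpuCompilationCacheService). A generic consumption sketch using standard grpc::ServerBuilder calls follows; MyCacheServiceImpl is a placeholder service type, and this is not how TpuCompilationCacheService itself wires in.

// Generic gRPC usage of the builder returned above.
xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> builder =
    tpu::CreateServerBuilder(serving_port);
TF_CHECK_OK(builder.status());
MyCacheServiceImpl service;  // hypothetical grpc::Service implementation
builder.ValueOrDie()->RegisterService(&service);
std::unique_ptr<::grpc::Server> server = builder.ValueOrDie()->BuildAndStart();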

View File

@ -15,9 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
#include <memory>
#include <string>
#include <vector>
#include "grpcpp/server_builder.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -55,10 +57,9 @@ Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
std::vector<TensorShape>* shapes);
// We only recycle ports which were given to us by the portserver. For ports
// we obtained through local trial-and-error, there is no reason to expect the
// port to remain available after it is unbound.
void RecycleUnusedPort(int port);
// Creates gRPC ServerBuilder.
xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> CreateServerBuilder(
int serving_port);
} // namespace tpu
} // namespace tensorflow

View File

@ -32,8 +32,9 @@ extern "C" {
TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
void* tpu_compilation_cache_interface, size_t* host_config_output_size,
char** host_config_output, TF_Status* status);
size_t server_address_size, const char* server_address,
size_t* host_config_output_size, char** host_config_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
const size_t num_hosts, const size_t num_cores_per_host,
@ -42,11 +43,9 @@ TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
size_t* tpu_topology_output_size, char** tpu_topology_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
const size_t tpu_host_config_size, const char* tpu_host_config,
const bool enable_whole_mesh_compilations, void* local_compilation_cache,
const bool enable_whole_mesh_compilations, bool is_master_worker,
size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);
TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
@ -65,12 +64,22 @@ TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
int64_t* cache_size_in_bytes);
TFTPU_CAPI_EXPORT
void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
size_t tpu_host_config_size, const char* tpu_host_config,
size_t* server_address_output_size, char** server_address_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
size_t* server_address_output_size, char** server_address_output,
int* port_output, TF_Status* status);
}
struct TfTpu_ConfigApiFn {
TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(ShutdownDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
@ -79,6 +88,10 @@ struct TfTpu_ConfigApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
TFTPU_ADD_FN_IN_STRUCT(
TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);
};
#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_

View File

@ -11,7 +11,6 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
@ -20,6 +19,11 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpusPerHost);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpuMemoryLimit);
TFTPU_SET_FN(config_fn,
TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
TFTPU_SET_FN(config_fn,
TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_GetServerAddressAndPort);
return tensorflow::Status::OK();
}