Add Remote Cache support for POD use cases.
PiperOrigin-RevId: 329862374
Change-Id: I31cd4841e0b0f08d9c09bc0cec0a5fd1abe6dc13

This commit is contained in:
parent 5104953f4c
commit 1a16406bcd
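In brief, the change lets non-master hosts in a TPU pod resolve compiled programs through a remote compilation-cache server instead of requiring a local cache: `ConfigureDistributedTpuOp` now publishes the cache server's address via the new `ConstructTpuPodState` helper, and `InitializeHostForDistributedTpuOp` builds a `TpuCompilationCacheRpcLookup` from that address when the host has no local cache. The sketch below is a minimal, self-contained illustration of that selection logic only; `CacheLookup`, `LocalCacheLookup`, `RpcCacheLookup`, and `HostConfig` are hypothetical stand-ins, not the TensorFlow types touched in the diff.

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <string>

// Hypothetical stand-in for the TPU compilation-cache lookup interface.
struct CacheLookup {
  virtual ~CacheLookup() = default;
  virtual std::string Describe() const = 0;
};

// Lookup backed by a cache living in this process (master worker case).
struct LocalCacheLookup : CacheLookup {
  std::string Describe() const override { return "local compilation cache"; }
};

// Lookup that dials a remote compilation-cache server (pod worker case).
struct RpcCacheLookup : CacheLookup {
  RpcCacheLookup(std::string addr, int64_t cache_bytes)
      : addr_(std::move(addr)), cache_bytes_(cache_bytes) {}
  std::string Describe() const override {
    return "remote cache at " + addr_ + " (" + std::to_string(cache_bytes_) +
           " bytes)";
  }
  std::string addr_;
  int64_t cache_bytes_;
};

// Hypothetical host config: the master worker owns the cache; other pod
// hosts only know the cache server's address from the distributed config.
struct HostConfig {
  bool is_master_worker = false;
  std::optional<std::string> cache_server_address;
};

// Mirrors the branch this commit adds to host initialization: keep the local
// lookup on the master, otherwise point an RPC lookup at the cache server.
std::unique_ptr<CacheLookup> MakeCacheLookup(const HostConfig& config,
                                             int64_t remote_cache_bytes) {
  if (config.is_master_worker || !config.cache_server_address) {
    return std::make_unique<LocalCacheLookup>();
  }
  return std::make_unique<RpcCacheLookup>(*config.cache_server_address,
                                          remote_cache_bytes);
}

int main() {
  HostConfig pod_worker{/*is_master_worker=*/false, "10.0.0.1:8470"};
  std::cout << MakeCacheLookup(pod_worker, int64_t{1} << 31)->Describe()
            << "\n";  // remote cache at 10.0.0.1:8470 (2147483648 bytes)
}
```

The diff itself follows; the server address and cache size in the real code come from the `TpuConfigurationApi_*` C API functions added below.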
@@ -72,7 +72,6 @@ load(
"if_ios",
"if_mobile",
"if_not_windows",
"if_tpu",
"tf_android_core_proto_headers",
"tf_cc_test",
"tf_cc_test_mkl",
@@ -117,6 +116,7 @@ load(
"tf_protos_all_impl",
"tf_protos_grappler_impl",
"tf_protos_profiler_impl",
"tf_tpu_dependencies",
)
load(
"//tensorflow/core/platform:rules_cc.bzl",
@@ -1086,9 +1086,7 @@ cc_library(
]) + if_tensorrt([
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
"//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
]) + if_tpu([
"//tensorflow/core/tpu/kernels",
]),
]) + tf_tpu_dependencies(),
)

cc_library(

@@ -43,6 +43,7 @@ load(
_tf_py_clif_cc = "tf_py_clif_cc",
_tf_pyclif_proto_library = "tf_pyclif_proto_library",
_tf_resource_deps = "tf_resource_deps",
_tf_tpu_dependencies = "tf_tpu_dependencies",
_tf_windows_aware_platform_deps = "tf_windows_aware_platform_deps",
)

@@ -88,3 +89,4 @@ tf_py_clif_cc = _tf_py_clif_cc
tf_pyclif_proto_library = _tf_pyclif_proto_library
tf_resource_deps = _tf_resource_deps
tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps
tf_tpu_dependencies = _tf_tpu_dependencies

@@ -1,7 +1,7 @@
# Platform-specific build configurations.

load("@com_google_protobuf//:protobuf.bzl", "proto_gen")
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows")
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows", "if_tpu")
load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
@@ -800,3 +800,6 @@ def if_llvm_system_z_available(then, otherwise = []):
"//tensorflow:linux_s390x": then,
"//conditions:default": otherwise,
})

def tf_tpu_dependencies():
return if_tpu(["//tensorflow/core/tpu/kernels"])

@@ -99,13 +99,22 @@ tf_kernel_library(
name = "tpu_configuration_ops",
srcs = ["tpu_configuration_ops.cc"],
hdrs = ["tpu_configuration_ops.h"],
deps = [
copts = select({
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
DEFAULT: [],
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_util"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
}) + [
":tpu_compilation_cache_factory",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_local_lookup",
":tpu_compilation_cache_lookup",
":tpu_compilation_cache_rpc_lookup",
":tpu_mesh_state_interface",
":tpu_op_consts",
":tpu_pod_state",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/xla:util",
@@ -116,6 +125,7 @@ tf_kernel_library(
"//tensorflow/core/tpu:tpu_config_c_api",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/tpu:proto_helper",
],
alwayslink = 1,
@@ -447,6 +457,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
tf_grpc_cc_dependency(),
],
)

@@ -505,10 +516,18 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/tpu:tpu_api",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
tf_grpc_cc_dependency(),
],
alwayslink = 1,
)

# An alias for
cc_library(
name = "tpu_compilation_cache_cc_proto",
deps = [":tpu_compilation_cache_proto_cc"],
)

cc_library(
name = "tpu_compilation_cache_rpc_support_hdrs",
hdrs = ["tpu_compilation_cache_rpc_support.h"],
@@ -518,7 +537,7 @@ cc_library(
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], # build_cleaner: keep
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], # build_cleaner: keep
}) + [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_interface",
@@ -606,7 +625,7 @@ cc_library(
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
}) + [
":tpu_compilation_cache_common_proto_cc",
tf_grpc_cc_dependency(),
@@ -628,7 +647,7 @@ cc_library(
],
DEFAULT: [
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc", # build_cleaner: keep
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto", # build_cleaner: keep
],
}) + [
":tpu_compilation_cache_common_proto_cc",
@@ -939,10 +958,14 @@ cc_library(
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
DEFAULT: [],
}),
deps = [
deps = select({
WITH_TPU_SUPPORT: [":tpu_util"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
}) + [
":tpu_compilation_cache_service",
":tpu_util",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core:framework",
tf_grpc_cc_dependency(),
],
)

@@ -27,8 +27,10 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
@@ -37,7 +39,6 @@ limitations under the License.

namespace tensorflow {
namespace {

Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
tpu::TpuMeshStateInterface** state) {
if (!rmgr->Lookup(rmgr->default_container(),
@@ -69,7 +70,6 @@ Status DeleteIfExists(ResourceMgr* resource_manager,
VLOG(1) << "Error removing resource " << resource_name << " : " << status;
return status;
}

} // namespace

Status CreateTpuCompilationCache(
@@ -82,36 +82,39 @@ Status CreateTpuCompilationCache(
});
}

void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ConfigureDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
OpKernelContext* ctx) {
std::vector<int32_t> num_devices_per_host;
int chips_per_host = -1;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& input_tensor = ctx->input(i);
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(input_tensor.shape()),
errors::InvalidArgument("Input ", i, " should be a scalar but has ",
input_tensor.dims(), " dimensions"));
if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
return errors::InvalidArgument("Input ", i,
" should be a scalar but has ",
input_tensor.dims(), " dimensions");
}
if (chips_per_host == -1) {
chips_per_host = input_tensor.scalar<int32_t>()();
} else {
OP_REQUIRES(
ctx, chips_per_host == input_tensor.scalar<int32>()(),
errors::Internal("Host ", i, " has ", input_tensor.scalar<int32>()(),
" TPU chips but host 0 has ", chips_per_host));
if (chips_per_host != input_tensor.scalar<int32>()()) {
return errors::Internal("Host ", i, " has ",
input_tensor.scalar<int32>()(),
" TPU chips but host 0 has ", chips_per_host);
}
}
num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
}
return num_devices_per_host;
}

TF_Status* status = TF_NewStatus();
size_t host_config_output_size;
char* host_config_output;
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ConfigureDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

auto* rmgr = GetTPUConfigResourceMgr();
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
rmgr, tpu::kTpuMeshStateInterfaceResourceName));
xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
ConstructDevicesPerHost(ctx);
OP_REQUIRES_OK(ctx, num_devices_per_host.status());
ResourceMgr* rmgr = GetTPUConfigResourceMgr();

// Create the subgraph compilation cache and put it in the local resource
// manager.
@@ -119,9 +122,13 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
core::ScopedUnref compilation_cache_ref(compilation_cache);

tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
num_devices_per_host.size(), num_devices_per_host.data(),
compilation_cache, &host_config_output_size, &host_config_output, status);
std::string host_config_output;
OP_REQUIRES_OK(
ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
&host_config_output));

OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
rmgr, tpu::kTpuMeshStateInterfaceResourceName));

auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
OP_REQUIRES_OK(
@@ -130,13 +137,7 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {

Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
ctx_output->scalar<tstring>()() =
std::string(host_config_output, host_config_output_size);

OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);

tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
ctx_output->scalar<tstring>()() = std::move(host_config_output);

VLOG(1) << "ConfigureDistributedTpuOp done";
}
@@ -186,30 +187,39 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
mapping_arg.push_back(mapping[i].data());
}

TF_Status* status = TF_NewStatus();
size_t tpu_topology_output_size;
char* tpu_topology_output;

tpu::TpuMeshStateInterface* mesh_state;
auto* rmgr = GetTPUConfigResourceMgr();
OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
core::ScopedUnref mesh_state_unref(mesh_state);

// TODO(b/166858751): this code to check if `TpuPodState` exists is ported
// from a legacy library that may have staled. A candidate for cleanup.
TpuPodState* pod_state;
OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
core::ScopedUnref pod_state_unref(pod_state);

size_t tpu_topology_output_size;
char* tpu_topology_output = nullptr;
TF_Status* status = TF_NewStatus();
auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
tpu_topology_output);
});

auto* mesh_common_state = mesh_state->mesh_common_state();
tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
num_hosts, num_devices_per_host,
const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
&tpu_topology_output_size, &tpu_topology_output, status);

OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
ctx_output->scalar<tstring>()() =
std::string(tpu_topology_output, tpu_topology_output_size);

OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);

VLOG(1) << "WaitForDistributedTpuOp done";
}

@@ -217,17 +227,14 @@ void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ShutdownDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");

TF_Status* status = TF_NewStatus();
auto* rmgr = GetTPUConfigResourceMgr();
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
GetTPUConfigResourceMgr(),
tpu::kTpuMeshStateInterfaceResourceName));
tpu::ConfigApiFn()->ShutdownDistributedTpuOp_DoWorkFn(status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
rmgr, tpu::kTpuMeshStateInterfaceResourceName));

OP_REQUIRES_OK(
ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
GetTPUConfigResourceMgr(), tpu::kCompilationCacheResourceName));
OP_REQUIRES_OK(ctx,
DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
rmgr, tpu::kCompilationCacheResourceName));

VLOG(1) << "ShutdownDistributedTpuOp done";
}
@@ -239,10 +246,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
auto* rmgr = GetTPUConfigResourceMgr();
auto tpu_host_config = ctx->input(0).scalar<tstring>()();

size_t device_id_output_size;
int32_t* device_id_output;
TF_Status* status = TF_NewStatus();

bool is_master_worker =
tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
if (!is_master_worker) {
@@ -275,10 +278,18 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
local_compilation_cache = nullptr;
}

TF_Status* status = TF_NewStatus();
size_t device_id_output_size;
int32_t* device_id_output = nullptr;
auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
});
tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
tpu_host_config.size(), tpu_host_config.data(),
enable_whole_mesh_compilations_, local_compilation_cache,
&device_id_output_size, &device_id_output, status);
enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size,
&device_id_output, status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

if (local_compilation_cache != nullptr) {
local_compilation_cache->Unref();
@@ -289,6 +300,30 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(
ctx, rmgr->Create(rmgr->default_container(),
tpu::kCompiledProtoCacheResourceName, proto_lookup));
} else {
int64_t cache_size_bytes;
tpu::ConfigApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
&cache_size_bytes);

char* server_address_output = nullptr;
auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
server_address_output);
});
size_t server_address_output_size;
tpu::ConfigApiFn()
->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
tpu_host_config.size(), tpu_host_config.data(),
&server_address_output_size, &server_address_output, status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

std::string server_address(server_address_output,
server_address_output_size);
tpu::TpuCompilationCacheLookup* proto_lookup =
new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes);
OP_REQUIRES_OK(
ctx, rmgr->Create(rmgr->default_container(),
tpu::kCompiledProtoCacheResourceName, proto_lookup));
}

Tensor* ctx_output;
@@ -301,10 +336,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
ctx_output->flat<int32>()(i) = device_id_output[i];
}

OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);

VLOG(1) << "InitializeHostForDistributedTpuOp done";
}

@@ -15,14 +15,22 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_

#include <stdint.h>

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

Status CreateTpuCompilationCache(
ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);

xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
OpKernelContext* ctx);

// The ConfigureDistributedTpu op is used to start an TPUDriver from
// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
// connection host:port for the CompilationCacheServer. The
@@ -14,12 +14,78 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/tpu/tpu_api.h"

#if defined(LIBTFTPU)
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#else
#include "tensorflow/core/tpu/kernels/tpu_util.h" // copybara"
#endif

namespace tensorflow {

const char kTpuPodStateResourceName[] = "tpu_pod_state";

namespace {
Status GetServerAddressAndPort(std::string* server_address, int* serving_port) {
TF_Status* status = TF_NewStatus();
char* server_address_output = nullptr;
auto cleanup = xla::MakeCleanup([&status, &server_address_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
server_address_output);
});
size_t server_address_output_size;
*serving_port = -1;
tpu::ConfigApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(
&server_address_output_size, &server_address_output, serving_port,
status);
CHECK_NE(*serving_port, -1);
TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
return Status::OK();
}

// Attempt to delete resource_name from resource_manager's default_container.
// Returns OK if the deletion succeeded, or if the resource was not found. Else
// return the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
const char* resource_name) {
VLOG(1) << "Removing resource " << resource_name << " if it exists";
Status status = resource_manager->Delete<ResourceT>(
resource_manager->default_container(), resource_name);
if (status.ok()) {
VLOG(1) << "Removed existing resource " << resource_name;
return Status::OK();
}
if (status.code() == error::NOT_FOUND) {
VLOG(1) << "No resource " << resource_name << " to remove";
return Status::OK();
}
VLOG(1) << "Error removing resource " << resource_name << " : " << status;
return status;
}

xla::StatusOr<std::unique_ptr<TpuCompilationCacheService>>
ConstructCacheService(ResourceMgr* rmgr, int serving_port,
tpu::TpuCompilationCacheInterface* compilation_cache) {
xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> server_builder;
#if defined(LIBTFTPU)
server_builder = tpu::CreateServerBuilder(serving_port);
#else
server_builder = tpu::CreateServerBuilderGoogle(serving_port);
#endif
TF_RETURN_IF_ERROR(server_builder.status());

auto cache_service = absl::make_unique<TpuCompilationCacheService>(
server_builder.ValueOrDie().get(), compilation_cache);
cache_service->SetMemoryQuota(1ul << 31); // 2GB
cache_service->Start();
return cache_service;
}
} // namespace

TpuPodState::TpuPodState(
int service_port, std::unique_ptr<TpuCompilationCacheService> cache_service)
: cache_service_(std::move(cache_service)), service_port_(service_port) {}
@@ -29,7 +95,7 @@ TpuPodState::~TpuPodState() {
VLOG(1) << "Shutting down Compilation Cache Service.";
if (cache_service_->Shutdown(20)) {
if (service_port_ >= 0) {
tpu::RecycleUnusedPort(service_port_);
tpu::UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
}
} else {
LOG(ERROR)
@@ -67,4 +133,38 @@ bool HasTPUPodState(const ResourceMgr* rmgr) {
return true;
}

Status ConstructTpuPodState(
ResourceMgr* rmgr, const std::vector<int32_t>& num_devices_per_host,
tpu::TpuCompilationCacheInterface* compilation_cache,
std::string* host_config_proto) {
TF_Status* status = TF_NewStatus();
auto status_cleanup =
xla::MakeCleanup([&status]() { TF_DeleteStatus(status); });

int serving_port;
std::string server_address;
TF_RETURN_IF_ERROR(GetServerAddressAndPort(&server_address, &serving_port));

char* host_config_output = nullptr;
auto host_config_cleanup = xla::MakeCleanup([&host_config_output]() {
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
});
size_t host_config_output_size;
tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
num_devices_per_host.size(), num_devices_per_host.data(),
server_address.size(), server_address.data(), &host_config_output_size,
&host_config_output, status);
TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
*host_config_proto = std::string(host_config_output, host_config_output_size);

TF_ASSIGN_OR_RETURN(
std::unique_ptr<TpuCompilationCacheService> cache_service,
ConstructCacheService(rmgr, serving_port, compilation_cache));

// Delete TpuPodState if it exists, and recreate below.
TF_RETURN_IF_ERROR(
DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
return rmgr->Create(rmgr->default_container(), kTpuPodStateResourceName,
new TpuPodState(serving_port, std::move(cache_service)));
}
} // namespace tensorflow

@@ -15,7 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_

#include "grpcpp/server_builder.h"
#include <string>
#include <vector>

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h"

@@ -49,6 +51,11 @@ Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state);
// manager.
bool HasTPUPodState(const ResourceMgr* rmgr);

// Construct TpuPodState.
Status ConstructTpuPodState(
ResourceMgr* rmgr, const std::vector<int32_t>& num_devices_per_host,
tpu::TpuCompilationCacheInterface* compilation_cache,
std::string* host_config_proto);
} // namespace tensorflow

#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_util.h"

#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/tpu/tpu_api.h"
@@ -97,8 +98,13 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
return Status::OK();
}

void RecycleUnusedPort(int port) {
UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(port);
xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> CreateServerBuilder(
int serving_port) {
auto server_builder = absl::make_unique<::grpc::ServerBuilder>();
server_builder->AddListeningPort(
absl::StrFormat("[::]:%d", serving_port),
::grpc::InsecureServerCredentials()); // NOLINT
return std::move(server_builder);
}
} // namespace tpu
} // namespace tensorflow

@@ -15,9 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_

#include <memory>
#include <string>
#include <vector>

#include "grpcpp/server_builder.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -55,10 +57,9 @@ Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
std::vector<TensorShape>* shapes);

// We only recycle ports which were given to us by the portserver. For ports
// we obtained through local trial-and-error, there is no reason to expect the
// port to remain available after it is unbound.
void RecycleUnusedPort(int port);
// Creates gRPC ServerBuilder.
xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> CreateServerBuilder(
int serving_port);
} // namespace tpu
} // namespace tensorflow

@@ -32,8 +32,9 @@ extern "C" {

TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
void* tpu_compilation_cache_interface, size_t* host_config_output_size,
char** host_config_output, TF_Status* status);
size_t server_address_size, const char* server_address,
size_t* host_config_output_size, char** host_config_output,
TF_Status* status);

TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
const size_t num_hosts, const size_t num_cores_per_host,
@@ -42,11 +43,9 @@ TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
size_t* tpu_topology_output_size, char** tpu_topology_output,
TF_Status* status);

TFTPU_CAPI_EXPORT void ShutdownDistributedTpuOp_DoWork(TF_Status* status);

TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
const size_t tpu_host_config_size, const char* tpu_host_config,
const bool enable_whole_mesh_compilations, void* local_compilation_cache,
const bool enable_whole_mesh_compilations, bool is_master_worker,
size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);

TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
@@ -65,12 +64,22 @@ TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
TF_Status* status);

TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
int64_t* cache_size_in_bytes);
TFTPU_CAPI_EXPORT
void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
size_t tpu_host_config_size, const char* tpu_host_config,
size_t* server_address_output_size, char** server_address_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
size_t* server_address_output_size, char** server_address_output,
int* port_output, TF_Status* status);
}

struct TfTpu_ConfigApiFn {
TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(ShutdownDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
@@ -79,6 +88,10 @@ struct TfTpu_ConfigApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
TFTPU_ADD_FN_IN_STRUCT(
TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);
};

#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_

@@ -11,7 +11,6 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {

TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
@@ -20,6 +19,11 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpusPerHost);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpuMemoryLimit);
TFTPU_SET_FN(config_fn,
TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
TFTPU_SET_FN(config_fn,
TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_GetServerAddressAndPort);

return tensorflow::Status::OK();
}