[TPU 1VM] Consolidate all TPU ops related APIs into a single file

PiperOrigin-RevId: 338777305
Change-Id: I84fa82f0efeef7f5e64896018da244908ecef11a
This commit is contained in:
Frank Chen 2020-10-23 17:58:15 -07:00 committed by TensorFlower Gardener
parent df7d1daff4
commit 35d27a0335
35 changed files with 516 additions and 819 deletions

View File

@ -110,30 +110,15 @@ cc_library(
],
)
cc_library(
name = "tpu_config_c_api",
hdrs = ["tpu_config_c_api.h"],
deps = [
":libtftpu_header",
"//tensorflow/c:tf_status",
],
alwayslink = True,
)
cc_library(
name = "tpu_api",
srcs = ["tpu_api.cc"],
hdrs = ["tpu_api.h"],
deps = [
":libtftpu_header",
":tpu_config_c_api",
":tpu_executor_api",
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
":tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
],
)
@ -160,20 +145,15 @@ cc_library(
":tpu_api",
":tpu_api_dlsym_set_fn",
":tpu_compilation_device",
":tpu_config_c_api",
":tpu_executor_init_fns",
":tpu_library_init_fns",
":tpu_node_device",
":tpu_ops_c_api_hdrs",
":tpu_system_device",
"//tensorflow/core:lib",
"//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_computation_placer",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
],
)
@ -292,7 +272,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_executable_info_proto_cc",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:stream",
@ -351,3 +331,16 @@ cc_library(
"//tensorflow/stream_executor/tpu:tpu_transfer_manager",
],
)
cc_library(
name = "tpu_ops_c_api_hdrs",
srcs = [],
hdrs = ["tpu_ops_c_api.h"],
visibility = ["//visibility:public"],
deps = [
":libtftpu_header",
"//tensorflow/stream_executor/tpu:c_api_decl",
"//tensorflow/stream_executor/tpu:proto_helper",
],
alwayslink = True,
)

View File

@ -158,7 +158,7 @@ cc_library(
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"//tensorflow/core/tpu:tpu_compile_interface",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
"//tensorflow/stream_executor/tpu:tpu_topology_external",
"@com_google_absl//absl/algorithm:container",

View File

@ -64,9 +64,9 @@ limitations under the License.
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

View File

@ -56,7 +56,7 @@ cc_library(
":tpu_op_util",
":tpu_program_group_interface",
":tpu_util",
":tpu_util_c_api_hdrs",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
":tpu_util_hdrs",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@ -114,7 +114,7 @@ tf_kernel_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu:tpu_config_c_api",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor/lib",
@ -123,19 +123,6 @@ tf_kernel_library(
alwayslink = 1,
)
cc_library(
name = "tpu_compile_c_api_hdrs",
hdrs = ["tpu_compile_c_api.h"],
deps = [
":tpu_mesh_state_c_api_hdrs",
":tpu_program_c_api_hdrs",
":tpu_util_c_api_hdrs",
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:c_api_decl",
],
alwayslink = True,
)
tf_proto_library(
name = "tpu_executable_info_proto",
srcs = ["tpu_executable_info.proto"],
@ -261,24 +248,16 @@ cc_library(
],
)
cc_library(
name = "tpu_mesh_state_c_api_hdrs",
hdrs = ["tpu_mesh_state_c_api.h"],
deps = ["//tensorflow/core/tpu:libtftpu_header"],
alwayslink = True,
)
cc_library(
name = "tpu_mesh_state_interface",
srcs = [],
hdrs = ["tpu_mesh_state_interface.h"],
deps = [
":tpu_compile_c_api_hdrs",
":tpu_mesh_state_c_api_hdrs",
"//tensorflow/compiler/xla/service",
"//tensorflow/core:framework",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
],
)
@ -310,14 +289,11 @@ cc_library(
srcs = ["tpu_program_group.cc"],
hdrs = ["tpu_program_group.h"],
deps = [
":tpu_compile_c_api_hdrs",
":tpu_compile_op_common",
":tpu_compile_op_support",
":tpu_compile_proto_cc",
":tpu_executable_info_proto_cc",
":tpu_mesh_state_c_api_hdrs",
":tpu_mesh_state_interface",
":tpu_program_c_api_hdrs",
":tpu_program_group_interface",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -329,6 +305,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor/tpu:proto_helper",
"//tensorflow/stream_executor/tpu:status_helper",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
@ -382,11 +359,9 @@ cc_library(
":tpu_compilation_cache_key",
":tpu_compilation_metrics", # buildcleaner: keep
":tpu_compilation_metrics_hdrs",
":tpu_compile_c_api_hdrs",
":tpu_compile_op_support",
":tpu_mesh_state_interface",
":tpu_op_consts",
":tpu_program_c_api_hdrs",
":tpu_program_group",
":tpu_util",
":trace_util_hdrs",
@ -398,6 +373,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -450,42 +426,18 @@ cc_library(
],
)
cc_library(
name = "tpu_util_c_api_hdrs",
hdrs = ["tpu_util_c_api.h"],
deps = [
":tpu_mesh_state_c_api_hdrs",
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:c_api_decl",
"//tensorflow/stream_executor/tpu:proto_helper",
],
alwayslink = True,
)
cc_library(
name = "tpu_program_c_api_hdrs",
hdrs = ["tpu_program_c_api.h"],
deps = [
":tpu_util_c_api_hdrs",
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:c_api_decl",
"//tensorflow/stream_executor/tpu:proto_helper",
],
alwayslink = True,
)
cc_library(
name = "tpu_op_util",
srcs = ["tpu_op_util.cc"],
hdrs = ["tpu_op_util.h"],
deps = [
":tpu_compilation_cache_key",
":tpu_compile_c_api_hdrs",
":tpu_mesh_state_interface",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"@com_google_absl//absl/strings",
],
)
@ -496,7 +448,6 @@ cc_library(
hdrs = ["tpu_util.h"],
deps = [
":tpu_compilation_cache_key",
":tpu_util_c_api_hdrs",
"//tensorflow/cc:ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
@ -504,6 +455,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
tf_grpc_cc_dependency(),
@ -548,7 +500,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/tpu:tpu_config_c_api",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor/tpu:proto_helper",
],
)
@ -665,17 +617,6 @@ cc_library(
],
)
cc_library(
name = "tpu_execute_c_api_hdrs",
hdrs = ["tpu_execute_c_api.h"],
deps = [
":tpu_program_c_api_hdrs",
":tpu_util_c_api_hdrs",
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
],
)
cc_library(
name = "tpu_compile_op_impl",
srcs = ["tpu_compile_op_impl.cc"],
@ -683,11 +624,9 @@ cc_library(
copts = tf_copts(),
deps = [
":tpu_compilation_cache_key",
":tpu_compile_c_api_hdrs",
":tpu_compile_op_common",
":tpu_compile_op_support",
":tpu_compile_proto_cc",
":tpu_mesh_state_c_api_hdrs",
":tpu_program_group",
":tpu_program_group_interface",
":tpu_util",
@ -696,6 +635,7 @@ cc_library(
"//tensorflow/compiler/xla:status",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"@com_google_absl//absl/types:variant",

View File

@ -25,11 +25,10 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/trace_util.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
namespace tpu {

View File

@ -34,11 +34,11 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
namespace tpu {

View File

@ -434,7 +434,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
// Check if caller has disabled compilation. Set using
// internal::ScopedTpuCompileDisabler.
if (!UtilApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
if (!OpsApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
const std::string error_msg = strings::StrCat(
"[TpuCompilationDisabled]: Compilation cache miss, but compilation "
"disabled, session_name(",

View File

@ -1,44 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
#include <stddef.h>
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
extern "C" {
// Compiles Mlir or TF function computation by lowering into HLO IR and returns
// `count` number of TPU programs ready for execution.
// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates
// `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is
// responsible to deallocate both the `XLA_TpuProgram*[]` array and the
// `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free`
// API respectively.
TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild(
TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state,
XLA_TpuProgram** tpu_programs[], size_t* count, SE_Status* status);
struct TfTpu_CompileApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);
};
} // extern "C"
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_

View File

@ -39,10 +39,10 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_op_util.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
namespace tpu {
@ -543,7 +543,7 @@ void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
ctx->cancellation_manager()->get_cancellation_token();
const bool already_cancelled =
!ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
if (UtilApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
if (OpsApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
return;
}

View File

@ -17,9 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
namespace tpu {

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
namespace tpu {

View File

@ -32,9 +32,9 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
namespace tensorflow {
@ -203,12 +203,11 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
TF_Status* status = TF_NewStatus();
auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
tpu_topology_output);
tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
});
auto* mesh_common_state = mesh_state->mesh_common_state();
tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(
num_hosts, num_devices_per_host,
const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
&tpu_topology_output_size, &tpu_topology_output, status);
@ -247,7 +246,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
auto tpu_host_config = ctx->input(0).scalar<tstring>()();
bool is_master_worker =
tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
if (!is_master_worker) {
// Reset the mesh interface if we are not the master.
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
@ -283,9 +282,9 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
int32_t* device_id_output = nullptr;
auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
});
tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
tpu_host_config.size(), tpu_host_config.data(),
enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size,
&device_id_output, status);
@ -302,16 +301,16 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
tpu::kCompiledProtoCacheResourceName, proto_lookup));
} else {
int64_t cache_size_bytes;
tpu::ConfigApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
&cache_size_bytes);
char* server_address_output = nullptr;
auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(
server_address_output);
});
size_t server_address_output_size;
tpu::ConfigApiFn()
tpu::OpsApiFn()
->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
tpu_host_config.size(), tpu_host_config.data(),
&server_address_output_size, &server_address_output, status);
@ -346,7 +345,7 @@ void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
auto tpu_topology = ctx->input(0).scalar<tstring>()();
TF_Status* status = TF_NewStatus();
tpu::ConfigApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
tpu_topology.data(), status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
@ -362,7 +361,7 @@ void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
TF_Status* status = TF_NewStatus();
int32_t number_of_chips_output = 0;
tpu::ConfigApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
&number_of_chips_output, status);
Tensor* ctx_output;

View File

@ -1,59 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
extern "C" {
typedef struct XLA_DeviceAssignment {
const char* bytes;
size_t size;
} XLA_DeviceAssignment;
TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments,
size_t arguments_len, SE_DeviceMemoryBase* result,
SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed,
XLA_DeviceAssignment* device_assignment, SE_Stream* stream,
SE_Status* status);
TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
XLA_Shape* host_shape, XLA_Shape* device_shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
uint32_t* runtime_input_ptr, size_t runtime_input_size,
int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape,
XLA_Shape* compile_time_shape, SE_Status* status);
struct TfTpu_ExecuteApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);
TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);
};
} // extern "C"
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_

View File

@ -1,43 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
#include "tensorflow/core/tpu/libtftpu.h"
typedef struct XLA_TpuMeshState XLA_TpuMeshState;
extern "C" {
// Creates a new TPU mesh state object.
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
// Deletes the given TPU `mesh_state` object. Once deleted the object is
// unusable.
TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
// Returns a pointer to an opaque mesh data structure used internally.
TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
XLA_TpuMeshState* mesh_state);
} // extern "C"
struct TfTpu_MeshStateApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
};
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_

View File

@ -19,9 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
@ -39,19 +38,19 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {
~TpuMeshStateInterface() override {
if (mesh_state_ != nullptr) {
MeshStateApiFn()->TpuMeshState_FreeFn(mesh_state_);
OpsApiFn()->TpuMeshState_FreeFn(mesh_state_);
}
}
static TpuMeshStateInterface* Create() {
return new TpuMeshStateInterface(MeshStateApiFn()->TpuMeshState_CreateFn());
return new TpuMeshStateInterface(OpsApiFn()->TpuMeshState_CreateFn());
}
const XLA_TpuMeshState* data() const { return mesh_state_; }
tensorflow::TpuMeshCommonState* mesh_common_state() const {
return static_cast<tensorflow::TpuMeshCommonState*>(
MeshStateApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_));
OpsApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_));
}
// Returns whether we should include the device assignment as a static field
@ -63,7 +62,7 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {
// Static device assignment enables XLA to perform certain optimization when
// all cores are used in the replicated computation.
return metadata.num_cores_per_replica() * metadata.num_replicas() ==
UtilApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_,
OpsApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_,
tpu_core_type);
}

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <string>
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
namespace tpu {
@ -77,7 +77,7 @@ std::string GuaranteedConstFingerprint(
uint64_t fingerprint = 0;
for (const Tensor& constant : guaranteed_constants) {
fingerprint =
tpu::UtilApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn(
tpu::OpsApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn(
fingerprint, constant.tensor_data().data(),
constant.tensor_data().size());
}
@ -110,7 +110,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
}
}
CompilationCacheKeyResult result =
tpu::UtilApiFn()->TpuCompile_CreateCompilationCacheKeyFn(
tpu::OpsApiFn()->TpuCompile_CreateCompilationCacheKeyFn(
CompilationCacheKeyProperty{
config_prefix.data(),
shapes_prefix.data(),
@ -125,7 +125,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
mesh_state.data(),
});
auto buffer_cleanup = gtl::MakeCleanup([result]() {
tpu::UtilApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
tpu::OpsApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
});
TpuCompilationCacheKey key;
key.prefix = result.key;

View File

@ -74,12 +74,11 @@ Status GetServerAddressAndPort(std::string* server_address, int* serving_port) {
char* server_address_output = nullptr;
auto cleanup = xla::MakeCleanup([&status, &server_address_output]() {
TF_DeleteStatus(status);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
server_address_output);
tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(server_address_output);
});
size_t server_address_output_size;
*serving_port = -1;
tpu::ConfigApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(
tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(
&server_address_output_size, &server_address_output, serving_port,
status);
TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
@ -98,7 +97,7 @@ TpuPodState::~TpuPodState() {
VLOG(1) << "Shutting down Compilation Cache Service.";
if (cache_service_->Shutdown(20)) {
if (service_port_ >= 0) {
tpu::UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
tpu::OpsApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
}
} else {
LOG(ERROR)
@ -150,10 +149,10 @@ Status ConstructTpuPodState(
char* host_config_output = nullptr;
auto host_config_cleanup = xla::MakeCleanup([&host_config_output]() {
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
});
size_t host_config_output_size;
tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
num_devices_per_host.size(), num_devices_per_host.data(),
server_address.size(), server_address.data(), &host_config_output_size,
&host_config_output, status);

View File

@ -1,135 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
typedef struct XLA_TpuProgram XLA_TpuProgram;
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
struct TpuExecutableSerializedProto {
const char* bytes;
size_t size;
};
struct CompilerMetadataSerializedProto {
const char* bytes;
size_t size;
};
struct HostComputeMetadataSerializedProto {
const char* bytes;
size_t size;
};
extern "C" {
// Creates a new TPU program.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New();
// Destroys the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program);
// Creates an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count);
// Destroys an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]);
// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
// destroyed, it is in an unusable state.
TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
SE_Status* status);
// Gets TPU program size in bytes from the `tpu_program`.
TFTPU_CAPI_EXPORT int64_t
TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);
// Logs the summary of current memory state snapshot of the `tpu_program`.
TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary(
const XLA_TpuProgram* tpu_program);
// Gets TPU program executable info from the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo(
const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info,
SE_Status* status);
// Gets host transfer info proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo(
const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info,
SE_Status* status);
// Gets HLO metadata proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata,
SE_Status* status);
// Gets may modify variables boolean value.
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
// Checks if TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
const XLA_TpuProgram* tpu_program);
// Gets TPU program by sharding type. Return value is valid only when the
// `status.status()` returns `OK`.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
// Gets TPU executable proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
SE_Status* status);
// Gets compilation metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
const XLA_TpuProgram* tpu_program,
CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status);
// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
SE_Status* status);
struct TfTpu_TpuProgramApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);
};
} // extern "C"
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_

View File

@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
@ -39,7 +38,7 @@ TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo(
VLOG(1) << "ConstructExecutableInfo";
TpuSerializedProto serialized_executable_info = {};
StatusHelper status;
TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
OpsApiFn()->TpuProgram_GetExecutableInfoFn(
xla_tpu_program, &serialized_executable_info, status.c_status);
TPUExecutableInfoProto executable_info;
if (status.ok()) {
@ -55,7 +54,7 @@ TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo(
VLOG(1) << "ConstructHostTransferInfo";
TpuSerializedProto serialized_host_transfer_info = {};
StatusHelper status;
TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
OpsApiFn()->TpuProgram_GetHostTransferInfoFn(
xla_tpu_program, &serialized_host_transfer_info, status.c_status);
TPUHostTransferInfoProto host_transfer_info;
if (status.ok()) {
@ -71,7 +70,7 @@ xla::HloProto TpuProgramGroup::ConstructHloMetadata(
VLOG(1) << "ConstructHloMetadata";
TpuSerializedProto serialized_hlo_metadata = {};
StatusHelper status;
TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(
OpsApiFn()->TpuProgram_GetHloMetadataFn(
xla_tpu_program, &serialized_hlo_metadata, status.c_status);
xla::HloProto hlo_metadata;
if (status.ok()) {
@ -97,8 +96,8 @@ void TpuProgramGroup::Initialize(
for (size_t i = 0; i < tpu_programs_.size(); ++i) {
const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i];
bool may_modify_variables;
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(
xla_tpu_program, &may_modify_variables);
OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program,
&may_modify_variables);
may_modify_variables_array[i] = may_modify_variables;
executable_infos[i] = ConstructExecutableInfo(xla_tpu_program);
host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program);
@ -114,7 +113,7 @@ void TpuProgramGroup::Initialize(
bool TpuProgramGroup::has_sharding_program() const {
for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
if (!TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
return false;
}
}
@ -126,7 +125,7 @@ size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
int64_t TpuProgramGroup::program_size() const {
int64_t total_size = 0;
for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
total_size += TpuProgramApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
}
return total_size;
}
@ -134,8 +133,7 @@ int64_t TpuProgramGroup::program_size() const {
bool TpuProgramGroup::LogProgramMemorySummary() {
bool success = true;
for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
success &=
TpuProgramApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
}
return success;
}
@ -143,8 +141,7 @@ bool TpuProgramGroup::LogProgramMemorySummary() {
void TpuProgramGroup::UnloadAndDestroyPrograms() {
for (XLA_TpuProgram* tpu_program : tpu_programs_) {
StatusHelper status;
TpuProgramApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program,
status.c_status);
OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status);
auto s = status.status();
if (!s.ok()) {
LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString();
@ -208,7 +205,7 @@ bool TpuProgramGroup::may_modify_variables(int index) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
bool may_modify_variables;
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
&may_modify_variables);
return may_modify_variables;
}
@ -258,7 +255,7 @@ Status TpuProgramGroup::CompileAndBuild(
size_t count = 0;
XLA_TpuProgram** xla_tpu_programs = nullptr;
StatusHelper status;
CompileApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
mesh_state, &xla_tpu_programs,
&count, status.c_status);
if (!status.ok()) {
@ -275,7 +272,7 @@ Status TpuProgramGroup::CompileAndBuild(
tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
tpu_program_group->Initialize(
absl::MakeConstSpan(&xla_tpu_programs[0], count));
TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
return status.status();
}
@ -284,8 +281,8 @@ std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
std::vector<XLA_TpuProgram*> tpu_programs;
tpu_programs.reserve(tpu_programs_.size());
for (size_t i = 0; i < tpu_programs_.size(); ++i) {
if (TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
tpu_programs.push_back(TpuProgramApiFn()->TpuProgram_GetTpuProgramFn(
if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn(
tpu_programs_[i], sharding_type));
CHECK_NE(tpu_programs[i], nullptr);
}
@ -300,11 +297,11 @@ Status TpuProgramGroup::DeserializeFromRpcResponseProtos(
for (size_t i = 0; i < rpc_response_protos.size(); ++i) {
StatusHelper status;
auto* xla_tpu_program = TpuProgramApiFn()->TpuProgram_NewFn();
TpuProgramApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn();
OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
rpc_response_protos[i], xla_tpu_program, status.c_status);
if (!status.status().ok()) {
TpuProgramApiFn()->TpuProgram_FreeFn(xla_tpu_program);
OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program);
return status.status();
}
tpu_programs[i] = xla_tpu_program;
@ -319,8 +316,8 @@ Status TpuProgramGroup::SerializeExecutable(
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
TpuProgramApiFn()->TpuProgram_SerializeTpuExecutableFn(
tpu_programs_[index], executable, status.c_status);
OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index],
executable, status.c_status);
return status.status();
}
@ -329,7 +326,7 @@ Status TpuProgramGroup::SerializeCompilerMetadata(
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
TpuProgramApiFn()->TpuProgram_SerializeCompilerMetadataFn(
OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn(
tpu_programs_[index], compiler_metadata, status.c_status);
return status.status();
}

View File

@ -27,10 +27,9 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {

View File

@ -1,91 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
// Property for creating compilation cache key.
// Bundles all inputs that identify one TPU compilation; passed by value to
// `TpuCompile_CreateCompilationCacheKey`, which derives a stable key from it.
// All pointers are borrowed: the caller keeps ownership and must keep them
// alive for the duration of the call.
struct CompilationCacheKeyProperty {
// Prefix strings mixed into the key; exact contents are produced by the
// caller (presumably config/shape fingerprints — confirm in tpu_compile).
const char* config_prefix;
const char* shapes_prefix;
// Name of the TF function being compiled.
const char* function_name;
// Serialized MLIR module text when compiling from MLIR; assumed empty/null
// otherwise — TODO confirm against callers.
const char* mlir_module;
// Participating device ids; `device_ids_size` is the element count of the
// `device_ids` array.
const int32_t* device_ids;
size_t device_ids_size;
int32_t guaranteed_constants_size;
// Fingerprint of the function library — presumably used so the key changes
// when dependent functions change; verify against caller.
uint64_t function_library_fingerprint;
int32_t num_cores_per_replica;
int32_t num_replicas;
// Mesh state the compilation targets (not owned by this struct).
const XLA_TpuMeshState* mesh_state;
};
// Compilation cache key result returning both the key and a more verbose debug
// version.
// Both strings are heap-allocated by `TpuCompile_CreateCompilationCacheKey`;
// the caller owns them and must release them by passing the whole result to
// `TpuCompile_DestroyCompilationCacheKey`.
struct CompilationCacheKeyResult {
const char* key;
const char* debug_string;
};
extern "C" {
// Checks if whether a TPU compilation is enabled.
TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled();
// XLA compilation cannot be cancelled. To avoid hanging the TF worker will exit
// when cancellation is requested for an XLA compile op. Some tests require this
// behavior to be disabled, and we test for this condition with the following
// flag function.
TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();
// Returns the number of available TPU core count.
TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);
// Recycle unused service port.
TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port);
// Creates a unique compilation cache `key` used for `put` and `get` operations.
// Returned buffers are heap-allocated and must be owned.
TFTPU_CAPI_EXPORT CompilationCacheKeyResult
TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property);
// Destroys the CompilationCacheKeyResult returned by calling the
// `TpuCompile_CreateCompilationCacheKey` API.
TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey(
CompilationCacheKeyResult result);
// Creates a guaranteed const fingerprint. Guarantee const is normally used in
// TPU inference to avoid re-copying unchanged variables onto the TPU device.
// It promises the value is identical for every execution in the same session
// even if the actual value changes in later executions.
TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
uint64_t fingerprint, const char* data, size_t size);
} // extern "C"
// Function-pointer table for the TPU utility C APIs declared above.
// Entries are resolved from the TPU shared library by the dlsym-based
// initializer (TFTPU_SET_FN) and called through the UtilApiFn() accessor.
struct TfTpu_UtilApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);
};
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_

View File

@ -23,39 +23,9 @@ TfTpu_BaseFn* InitializeApiFn() {
return &base_fn;
}
TfTpu_ConfigApiFn* ConfigApiFn() {
static TfTpu_ConfigApiFn config_api_fn;
return &config_api_fn;
}
TfTpu_MeshStateApiFn* MeshStateApiFn() {
static TfTpu_MeshStateApiFn mesh_state_api_fn;
return &mesh_state_api_fn;
}
TfTpu_CompileApiFn* CompileApiFn() {
static TfTpu_CompileApiFn compile_api_fn;
return &compile_api_fn;
}
TfTpu_ExecuteApiFn* ExecuteApiFn() {
static TfTpu_ExecuteApiFn execute_api_fn;
return &execute_api_fn;
}
TfTpu_TpuProgramApiFn* TpuProgramApiFn() {
static TfTpu_TpuProgramApiFn tpu_program_api_fn;
return &tpu_program_api_fn;
}
TfTpu_NodeContextApiFn* NodeContextApiFn() {
static TfTpu_NodeContextApiFn node_context_api_fn;
return &node_context_api_fn;
}
TfTpu_UtilApiFn* UtilApiFn() {
static TfTpu_UtilApiFn util_api_fn;
return &util_api_fn;
TfTpu_OpsApiFn* OpsApiFn() {
static TfTpu_OpsApiFn ops_api_fn;
return &ops_api_fn;
}
} // namespace tpu

View File

@ -16,33 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_TPU_API_H_
#define TENSORFLOW_CORE_TPU_TPU_API_H_
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_executor_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
namespace tensorflow {
namespace tpu {
TfTpu_BaseFn* InitializeApiFn();
TfTpu_ConfigApiFn* ConfigApiFn();
TfTpu_MeshStateApiFn* MeshStateApiFn();
TfTpu_CompileApiFn* CompileApiFn();
TfTpu_ExecuteApiFn* ExecuteApiFn();
TfTpu_TpuProgramApiFn* TpuProgramApiFn();
TfTpu_NodeContextApiFn* NodeContextApiFn();
TfTpu_UtilApiFn* UtilApiFn();
TfTpu_OpsApiFn* OpsApiFn();
} // namespace tpu
} // namespace tensorflow

View File

@ -17,14 +17,9 @@ limitations under the License.
#define TENSORFLOW_CORE_TPU_TPU_API_DLSYM_INITIALIZER_H_
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
// LINT.IfChange
namespace tensorflow {

View File

@ -1,97 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
#define TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
#include <cstddef>
#include <cstdint>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/tpu/libtftpu.h"
typedef struct TpuSerializedProto TpuSerializedProto;
namespace tensorflow {
class TpuMeshCommonState;
} // namespace tensorflow
extern "C" {
TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
size_t server_address_size, const char* server_address,
size_t* host_config_output_size, char** host_config_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
const size_t num_hosts, const size_t num_cores_per_host,
const int32_t** host_ordinal_to_global_core_id_map,
tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
size_t* tpu_topology_output_size, char** tpu_topology_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
const size_t tpu_host_config_size, const char* tpu_host_config,
const bool enable_whole_mesh_compilations, bool is_master_worker,
size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);
TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
const size_t tpu_topology_size, const char* tpu_topology,
TF_Status* status);
TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
int32_t* number_of_chips_output, TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
int64_t* cache_size_in_bytes);
TFTPU_CAPI_EXPORT
void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
size_t tpu_host_config_size, const char* tpu_host_config,
size_t* server_address_output_size, char** server_address_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
size_t* server_address_output_size, char** server_address_output,
int* port_output, TF_Status* status);
}
// Function-pointer table for the distributed-TPU configuration C APIs
// declared above. Entries are resolved from the TPU shared library by the
// dlsym-based initializer (TFTPU_SET_FN) and called through ConfigApiFn().
struct TfTpu_ConfigApiFn {
TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
TFTPU_ADD_FN_IN_STRUCT(
TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);
};
#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_

View File

@ -107,7 +107,7 @@ xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) {
XLA_Shape c_host_shape;
XLA_Shape c_device_shape;
ApiConverter::ToC(host_shape, &c_host_shape);
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
&c_host_shape, &c_device_shape);
xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
ApiConverter::Free(&c_host_shape);
@ -119,8 +119,7 @@ int64 ShapeSizeCompact(const xla::Shape& shape) {
XLA_Shape c_shape;
ApiConverter::ToC(shape, &c_shape);
int64 size =
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactFn(
&c_shape);
tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape);
ApiConverter::Free(&c_shape);
return size;
}
@ -129,7 +128,7 @@ int64 ShapeSizeCompactRaw(const xla::Shape& shape) {
XLA_Shape c_shape;
ApiConverter::ToC(shape, &c_shape);
int64 size =
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
&c_shape);
ApiConverter::Free(&c_shape);
return size;
@ -241,8 +240,7 @@ xla::Status UpdateDynamicInputs(
ApiConverter::ToC(runtime_shape, &c_runtime_shape);
ApiConverter::ToC(compile_time_shape, &c_compile_time_shape);
StatusHelper status;
tensorflow::tpu::ExecuteApiFn()
->TpuExecute_RuntimeInputToPaddedDataFn(
tensorflow::tpu::OpsApiFn()->TpuExecute_RuntimeInputToPaddedDataFn(
raw_input_runtime->data(), raw_input_runtime->size(),
padded_data->data(), padded_data->size(), &c_runtime_shape,
&c_compile_time_shape, status.c_status);

View File

@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

View File

@ -6,118 +6,76 @@
namespace {
tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
auto* config_fn = tensorflow::tpu::ConfigApiFn();
tensorflow::Status SetTpuOpsStructFns(void* library_handle) {
auto* ops_api_fn = tensorflow::tpu::OpsApiFn();
TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpusPerHost);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpuMemoryLimit);
TFTPU_SET_FN(config_fn,
TFTPU_SET_FN(ops_api_fn, ConfigureDistributedTpuOp_DoWork);
TFTPU_SET_FN(ops_api_fn, WaitForDistributedTpuOp_DoWork);
TFTPU_SET_FN(ops_api_fn, InitializeHostForDistributedTpuOp_DoWork);
TFTPU_SET_FN(ops_api_fn, SetGlobalTPUArrayOp_DoWork);
TFTPU_SET_FN(ops_api_fn, DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeCharArray);
TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeInt32Array);
TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_HasTPUPodState);
TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpusPerHost);
TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpuMemoryLimit);
TFTPU_SET_FN(ops_api_fn,
TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
TFTPU_SET_FN(config_fn,
TFTPU_SET_FN(ops_api_fn,
TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_GetServerAddressAndPort);
TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_GetServerAddressAndPort);
return tensorflow::Status::OK();
}
TFTPU_SET_FN(ops_api_fn, TpuMeshState_Create);
TFTPU_SET_FN(ops_api_fn, TpuMeshState_Free);
TFTPU_SET_FN(ops_api_fn, TpuMeshState_MeshCommonState);
tensorflow::Status SetTpuMeshStateStructFns(void* library_handle) {
auto* mesh_state_fn = tensorflow::tpu::MeshStateApiFn();
TFTPU_SET_FN(ops_api_fn, TpuCompile_CompileAndBuild);
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Create);
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Free);
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_MeshCommonState);
TFTPU_SET_FN(ops_api_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
TFTPU_SET_FN(ops_api_fn, HardwareLayout_HostShapeToDeviceShape);
TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSize);
TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompact);
TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompactRaw);
TFTPU_SET_FN(ops_api_fn, TpuExecute_RuntimeInputToPaddedData);
return tensorflow::Status::OK();
}
tensorflow::Status SetCompileStructFn(void* library_handle) {
auto* compile_fn = tensorflow::tpu::CompileApiFn();
TFTPU_SET_FN(compile_fn, TpuCompile_CompileAndBuild);
return tensorflow::Status::OK();
}
tensorflow::Status SetExecuteStructFn(void* library_handle) {
auto* execute_fn = tensorflow::tpu::ExecuteApiFn();
TFTPU_SET_FN(execute_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
TFTPU_SET_FN(execute_fn, HardwareLayout_HostShapeToDeviceShape);
TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSize);
TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompact);
TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompactRaw);
TFTPU_SET_FN(execute_fn, TpuExecute_RuntimeInputToPaddedData);
return tensorflow::Status::OK();
}
tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
auto* tpu_program_fn = tensorflow::tpu::TpuProgramApiFn();
TFTPU_SET_FN(tpu_program_fn, TpuProgram_New);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_Free);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_NewArray);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_FreeArray);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_UnloadAndDestroy);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetProgramSize);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_LogProgramMemorySummary);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetExecutableInfo);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHostTransferInfo);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeTpuExecutable);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeCompilerMetadata);
TFTPU_SET_FN(tpu_program_fn,
TFTPU_SET_FN(ops_api_fn, TpuProgram_New);
TFTPU_SET_FN(ops_api_fn, TpuProgram_Free);
TFTPU_SET_FN(ops_api_fn, TpuProgram_NewArray);
TFTPU_SET_FN(ops_api_fn, TpuProgram_FreeArray);
TFTPU_SET_FN(ops_api_fn, TpuProgram_UnloadAndDestroy);
TFTPU_SET_FN(ops_api_fn, TpuProgram_GetProgramSize);
TFTPU_SET_FN(ops_api_fn, TpuProgram_LogProgramMemorySummary);
TFTPU_SET_FN(ops_api_fn, TpuProgram_GetExecutableInfo);
TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHostTransferInfo);
TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHloMetadata);
TFTPU_SET_FN(ops_api_fn, TpuProgram_GetMayModifyVariables);
TFTPU_SET_FN(ops_api_fn, TpuProgram_HasSharding);
TFTPU_SET_FN(ops_api_fn, TpuProgram_GetTpuProgram);
TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeTpuExecutable);
TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeCompilerMetadata);
TFTPU_SET_FN(ops_api_fn,
TpuProgram_DeserializeFromGetTpuProgramResponseProto);
return tensorflow::Status::OK();
}
TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Create);
TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Free);
TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Initialize);
TFTPU_SET_FN(ops_api_fn, TpuNodeContext_StopChipHeartbeats);
TFTPU_SET_FN(ops_api_fn, TpuNodeContext_CloseTpuHost);
tensorflow::Status SetTpuNodeContextStructFns(void* library_handle) {
auto* node_context_fn = tensorflow::tpu::NodeContextApiFn();
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Create);
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Free);
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Initialize);
TFTPU_SET_FN(node_context_fn, TpuNodeContext_StopChipHeartbeats);
TFTPU_SET_FN(node_context_fn, TpuNodeContext_CloseTpuHost);
return tensorflow::Status::OK();
}
tensorflow::Status SetTpuUtilStructFns(void* library_handle) {
auto* util_fn = tensorflow::tpu::UtilApiFn();
TFTPU_SET_FN(util_fn, TpuTopology_AvailableCoreCount);
TFTPU_SET_FN(util_fn, TpuNetUtil_RecycleUnusedPort);
TFTPU_SET_FN(util_fn, TpuCompile_IsTpuCompilationEnabled);
TFTPU_SET_FN(util_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
TFTPU_SET_FN(util_fn, TpuCompile_CreateCompilationCacheKey);
TFTPU_SET_FN(util_fn, TpuCompile_DestroyCompilationCacheKey);
TFTPU_SET_FN(util_fn, TpuCompile_CreateGuaranteedConstFingerprint);
TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoreCount);
TFTPU_SET_FN(ops_api_fn, TpuNetUtil_RecycleUnusedPort);
TFTPU_SET_FN(ops_api_fn, TpuCompile_IsTpuCompilationEnabled);
TFTPU_SET_FN(ops_api_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey);
TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey);
TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint);
return tensorflow::Status::OK();
}
tensorflow::Status InitializeTpuStructFns(void* library_handle) {
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle));
TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle));
TF_RETURN_IF_ERROR(SetExecuteStructFn(library_handle));
TF_RETURN_IF_ERROR(SetTpuProgramStructFn(library_handle));
TF_RETURN_IF_ERROR(SetTpuOpsStructFns(library_handle));
TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle));
TF_RETURN_IF_ERROR(SetTpuUtilStructFns(library_handle));
return tensorflow::Status::OK();
}

View File

@ -0,0 +1,342 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_
#include <stddef.h>
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
typedef struct TpuSerializedProto TpuSerializedProto;
namespace tensorflow {
class TpuMeshCommonState;
} // namespace tensorflow
extern "C" {
typedef struct XLA_TpuProgram XLA_TpuProgram;
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
// Passed to `TpuProgram_GetTpuProgram` to select which sub-program to return;
// `kInvalid` (0) is the default/unset value — TODO confirm exact semantics of
// `kMain` vs `kSharding`/`kUnsharding` against the implementation.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
// Serialized TPU executable proto as a raw (pointer, length) buffer.
// Filled in by `TpuProgram_SerializeTpuExecutable`; `bytes` points to `size`
// bytes of serialized proto data. Ownership/free function is not visible
// here — confirm with the serialize API's contract.
struct TpuExecutableSerializedProto {
const char* bytes;
size_t size;
};
// Serialized compiler metadata proto as a raw (pointer, length) buffer.
// Filled in by `TpuProgram_SerializeCompilerMetadata`; `bytes` points to
// `size` bytes of serialized proto data.
struct CompilerMetadataSerializedProto {
const char* bytes;
size_t size;
};
// Serialized host-compute metadata proto as a raw (pointer, length) buffer.
// No producer/consumer is visible in this header excerpt — presumably filled
// by a TpuProgram_* serialization API; verify against the implementation.
struct HostComputeMetadataSerializedProto {
const char* bytes;
size_t size;
};
typedef struct XLA_TpuMeshState XLA_TpuMeshState;
// Serialized device assignment as a raw (pointer, length) buffer.
// Consumed by `TpuExecutable_LoadProgramAndEnqueueToStream`; presumably a
// serialized xla::DeviceAssignment proto — confirm with callers.
typedef struct XLA_DeviceAssignment {
const char* bytes;
size_t size;
} XLA_DeviceAssignment;
// Property for creating compilation cache key.
// Bundles all inputs that identify one TPU compilation; passed by value to
// `TpuCompile_CreateCompilationCacheKey`, which derives a stable key from it.
// All pointers are borrowed: the caller keeps ownership and must keep them
// alive for the duration of the call.
struct CompilationCacheKeyProperty {
// Prefix strings mixed into the key; exact contents are produced by the
// caller (presumably config/shape fingerprints — confirm in tpu_compile).
const char* config_prefix;
const char* shapes_prefix;
// Name of the TF function being compiled.
const char* function_name;
// Serialized MLIR module text when compiling from MLIR; assumed empty/null
// otherwise — TODO confirm against callers.
const char* mlir_module;
// Participating device ids; `device_ids_size` is the element count of the
// `device_ids` array.
const int32_t* device_ids;
size_t device_ids_size;
int32_t guaranteed_constants_size;
// Fingerprint of the function library — presumably used so the key changes
// when dependent functions change; verify against caller.
uint64_t function_library_fingerprint;
int32_t num_cores_per_replica;
int32_t num_replicas;
// Mesh state the compilation targets (not owned by this struct).
const XLA_TpuMeshState* mesh_state;
};
// Compilation cache key result returning both the key and a more verbose debug
// version.
// Both strings are heap-allocated by `TpuCompile_CreateCompilationCacheKey`;
// the caller owns them and is expected to release them via the matching
// destroy API (see `TpuCompile_DestroyCompilationCacheKey` in this API
// family — TODO confirm it is declared later in this header).
struct CompilationCacheKeyResult {
const char* key;
const char* debug_string;
};
typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
// Compiles Mlir or TF function computation by lowering into HLO IR and returns
// `count` number of TPU programs ready for execution.
// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates
// `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is
// responsible to deallocate both the `XLA_TpuProgram*[]` array and the
// `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free`
// API respectively.
TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild(
TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state,
XLA_TpuProgram** tpu_programs[], size_t* count, SE_Status* status);
// Creates a new TPU mesh state object.
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
// Deletes the given TPU `mesh_state` object. Once deleted the object is
// unusable.
TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
// Returns a pointer to an opaque mesh data structure used internally.
TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
XLA_TpuMeshState* mesh_state);
TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments,
size_t arguments_len, SE_DeviceMemoryBase* result,
SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed,
XLA_DeviceAssignment* device_assignment, SE_Stream* stream,
SE_Status* status);
TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
XLA_Shape* host_shape, XLA_Shape* device_shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
uint32_t* runtime_input_ptr, size_t runtime_input_size,
int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape,
XLA_Shape* compile_time_shape, SE_Status* status);
// Performs the host-side work for the ConfigureDistributedTpu op. On success,
// `*host_config_output` receives a heap-allocated buffer of
// `*host_config_output_size` bytes; release it with
// `TpuConfigurationApi_FreeCharArray`.
TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
    const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
    size_t server_address_size, const char* server_address,
    size_t* host_config_output_size, char** host_config_output,
    TF_Status* status);

// Performs the host-side work for the WaitForDistributedTpu op. Outputs a
// serialized TPU topology buffer; release `*tpu_topology_output` with
// `TpuConfigurationApi_FreeCharArray`.
TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
    const size_t num_hosts, const size_t num_cores_per_host,
    const int32_t** host_ordinal_to_global_core_id_map,
    tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
    size_t* tpu_topology_output_size, char** tpu_topology_output,
    TF_Status* status);

// Performs the host-side work for the InitializeHostForDistributedTpu op.
// Outputs this host's core ids; release `*core_id_output` with
// `TpuConfigurationApi_FreeInt32Array`.
TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
    const size_t tpu_host_config_size, const char* tpu_host_config,
    const bool enable_whole_mesh_compilations, bool is_master_worker,
    size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);

// Performs the host-side work for the SetGlobalTPUArray op, given a serialized
// topology buffer `tpu_topology` of `tpu_topology_size` bytes.
TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
    const size_t tpu_topology_size, const char* tpu_topology,
    TF_Status* status);

// Performs the host-side work for the DisconnectDistributedTpuChips op; the
// number of chips disconnected is reported via `number_of_chips_output`.
TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
    int32_t* number_of_chips_output, TF_Status* status);

// Deallocate a buffer previously returned through a `char**` / `int32_t**`
// output parameter of the configuration APIs above.
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);

// Returns true if a TPU pod state is present.
TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();

// Reports the number of TPUs per host via `tpus`.
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
                                                       TF_Status* status);

// Reports the TPU memory limit via `memory_limit`.
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
                                                          TF_Status* status);

// Reports the remote compilation cache size in bytes.
TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
    int64_t* cache_size_in_bytes);

// Extracts the compilation cache server address from a serialized host config
// buffer. Release `*server_address_output` with
// `TpuConfigurationApi_FreeCharArray`.
TFTPU_CAPI_EXPORT
void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
    size_t tpu_host_config_size, const char* tpu_host_config,
    size_t* server_address_output_size, char** server_address_output,
    TF_Status* status);

// Reports the compilation cache server address and port. Release
// `*server_address_output` with `TpuConfigurationApi_FreeCharArray`.
TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
    size_t* server_address_output_size, char** server_address_output,
    int* port_output, TF_Status* status);
// Creates a new TPU program.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New();

// Destroys the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program);

// Creates an array of `XLA_TpuProgram*` of length `count`.
TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count);

// Destroys an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]);

// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
// destroyed, it is in an unusable state.
TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
                                                   SE_Status* status);

// Gets TPU program size in bytes from the `tpu_program`.
TFTPU_CAPI_EXPORT int64_t
TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);

// Logs the summary of current memory state snapshot of the `tpu_program`.
TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary(
    const XLA_TpuProgram* tpu_program);

// Gets TPU program executable info from the `tpu_program` as a serialized
// proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info,
    SE_Status* status);

// Gets the host transfer info proto from the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info,
    SE_Status* status);

// Gets the HLO metadata proto from the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata,
    SE_Status* status);

// Gets the may-modify-variables boolean value from the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
    const XLA_TpuProgram* tpu_program, bool* may_modify_variables);

// Checks if the TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
    const XLA_TpuProgram* tpu_program);

// Gets the TPU program for the given sharding `type` from `tpu_program`.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
    XLA_TpuProgram* tpu_program, TpuProgramShardingType type);

// Gets the TPU executable proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
    const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
    SE_Status* status);

// Gets the compilation metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
    const XLA_TpuProgram* tpu_program,
    CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status);

// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
    TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
    SE_Status* status);

// Checks whether TPU compilation is enabled.
TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled();

// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will
// exit when cancellation is requested for an XLA compile op. Some tests
// require this behavior to be disabled, and we test for this condition with
// the following flag function.
TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();

// Returns the number of available TPU cores for the given core type.
TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
    const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);

// Recycles an unused service `port`.
TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port);

// Creates a unique compilation cache `key` used for `put` and `get` operations.
// Returned buffers are heap-allocated and must be released with
// `TpuCompile_DestroyCompilationCacheKey`.
TFTPU_CAPI_EXPORT CompilationCacheKeyResult
TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property);

// Destroys the CompilationCacheKeyResult returned by calling the
// `TpuCompile_CreateCompilationCacheKey` API.
TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey(
    CompilationCacheKeyResult result);

// Creates a guaranteed-const fingerprint. Guaranteed const is normally used in
// TPU inference to avoid re-copying unchanged variables onto the TPU device.
// It promises the value is identical for every execution in the same session
// even if the actual value changes in later executions.
TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
    uint64_t fingerprint, const char* data, size_t size);
XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
SE_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
void TpuNodeContext_CloseTpuHost(SE_Status* status);
void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
// Function-pointer table exposing every TPU ops C API declared above. Each
// member is declared via TFTPU_ADD_FN_IN_STRUCT; the pointers are filled in
// when the TPU library is loaded (presumably via dlsym in the library init
// code — confirm against the init-fns headers).
struct TfTpu_OpsApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);

  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);

  TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);
  TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);

  TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
  TFTPU_ADD_FN_IN_STRUCT(
      TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);

  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);

  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
  TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
  TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);

  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
};
} // extern "C"
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_

View File

@ -195,7 +195,6 @@ cc_library(
deps = [
":status_helper",
":tpu_executor_c_api_hdrs",
":tpu_node_context_c_api_hdrs",
":tpu_platform_interface",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:backend",
@ -204,6 +203,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/memory",
@ -293,9 +293,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_program_c_api_hdrs",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",

View File

@ -17,8 +17,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
@ -79,8 +79,7 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
run_options.run_options().stream()->implementation());
StatusHelper status;
tensorflow::tpu::ExecuteApiFn()
->TpuExecutable_LoadProgramAndEnqueueToStreamFn(
tensorflow::tpu::OpsApiFn()->TpuExecutable_LoadProgramAndEnqueueToStreamFn(
core_program_, arguments_bases, arguments.size(), &result_base,
(cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr),
rng_seed, &c_dev_assign, stream, status.c_status);
@ -96,7 +95,7 @@ Shape TpuExecutable::HostShapeToDeviceShape(const Shape& host_shape) {
XLA_Shape c_host_shape;
XLA_Shape c_device_shape;
ApiConverter::ToC(host_shape, &c_host_shape);
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
&c_host_shape, &c_device_shape);
Shape device_shape = ApiConverter::FromC(&c_device_shape);
ApiConverter::Free(&c_host_shape);
@ -108,7 +107,7 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) {
XLA_Shape c_shape;
ApiConverter::ToC(shape, &c_shape);
int64 size =
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeFn(&c_shape);
tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeFn(&c_shape);
ApiConverter::Free(&c_shape);
return size;
}

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
namespace tensorflow {
namespace tpu {
@ -30,40 +30,38 @@ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
int device_ordinal) {
StatusHelper status;
XLA_TpuNodeContext* node_context =
tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(device_ordinal,
status.c_status);
tpu::OpsApiFn()->TpuNodeContext_CreateFn(device_ordinal, status.c_status);
if (!status.status().ok()) {
// TpuNodeContext_CreateFn allocates a new XLA_TpuNodeContext regardless of
// status. It needs to be freed if it's not given to a TpuNodeContext below.
tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context);
tpu::OpsApiFn()->TpuNodeContext_FreeFn(node_context);
return status.status();
}
return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
}
TpuNodeContext::~TpuNodeContext() {
tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context_);
tpu::OpsApiFn()->TpuNodeContext_FreeFn(node_context_);
}
/* static */
Status TpuNodeContext::StopChipHeartbeats() {
StatusHelper status;
tpu::NodeContextApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status);
tpu::OpsApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status);
return status.status();
}
/* static */
Status TpuNodeContext::CloseTpuHost() {
StatusHelper status;
tpu::NodeContextApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status);
tpu::OpsApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status);
return status.status();
}
/* static */
Status TpuNodeContext::Initialize(int device_ordinal) {
StatusHelper status;
tpu::NodeContextApiFn()->TpuNodeContext_InitializeFn(device_ordinal,
status.c_status);
tpu::OpsApiFn()->TpuNodeContext_InitializeFn(device_ordinal, status.c_status);
return status.status();
}

View File

@ -24,11 +24,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {

View File

@ -147,7 +147,7 @@ void TpuPlatform::EraseEvent(stream_executor::internal::EventInterface* key) {
Status TpuPlatform::TpusPerHost(int* tpus) {
TF_Status* status = TF_NewStatus();
tpu::ConfigApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status);
tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status);
auto ret_status = StatusFromTF_Status(status);
TF_DeleteStatus(status);
return ret_status;
@ -155,7 +155,7 @@ Status TpuPlatform::TpusPerHost(int* tpus) {
Status TpuPlatform::TpuMemoryLimit(int64* memory_limit) {
TF_Status* status = TF_NewStatus();
tpu::ConfigApiFn()->TpuConfigurationApi_TpuMemoryLimitFn(
tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn(
reinterpret_cast<int64_t*>(memory_limit), status);
auto ret_status = StatusFromTF_Status(status);
TF_DeleteStatus(status);