diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 85586809014..f6e058a7efe 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -110,30 +110,15 @@ cc_library( ], ) -cc_library( - name = "tpu_config_c_api", - hdrs = ["tpu_config_c_api.h"], - deps = [ - ":libtftpu_header", - "//tensorflow/c:tf_status", - ], - alwayslink = True, -) - cc_library( name = "tpu_api", srcs = ["tpu_api.cc"], hdrs = ["tpu_api.h"], deps = [ ":libtftpu_header", - ":tpu_config_c_api", ":tpu_executor_api", - "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs", - "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs", - "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs", - "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs", + ":tpu_ops_c_api_hdrs", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", - "//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs", ], ) @@ -160,20 +145,15 @@ cc_library( ":tpu_api", ":tpu_api_dlsym_set_fn", ":tpu_compilation_device", - ":tpu_config_c_api", ":tpu_executor_init_fns", ":tpu_library_init_fns", ":tpu_node_device", + ":tpu_ops_c_api_hdrs", ":tpu_system_device", "//tensorflow/core:lib", "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration", - "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs", - "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs", - "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs", - "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs", "//tensorflow/stream_executor/tpu:tpu_computation_placer", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", - "//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs", ], ) @@ -292,7 +272,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "//tensorflow/core/tpu/kernels:tpu_executable_info_proto_cc", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:stream", @@ -351,3 +331,16 @@ cc_library( "//tensorflow/stream_executor/tpu:tpu_transfer_manager", ], ) + +cc_library( + name = "tpu_ops_c_api_hdrs", + srcs = [], + hdrs = ["tpu_ops_c_api.h"], + visibility = ["//visibility:public"], + deps = [ + ":libtftpu_header", + "//tensorflow/stream_executor/tpu:c_api_decl", + "//tensorflow/stream_executor/tpu:proto_helper", + ], + alwayslink = True, +) diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD index 36c3b6205e1..4e110f6348a 100644 --- a/tensorflow/core/tpu/graph_rewrite/BUILD +++ b/tensorflow/core/tpu/graph_rewrite/BUILD @@ -158,7 +158,7 @@ cc_library( "//tensorflow/core/protobuf/tpu:topology_proto_cc", "//tensorflow/core/tpu:tpu_compile_interface", "//tensorflow/core/tpu:tpu_defs", - "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "//tensorflow/stream_executor/tpu:tpu_platform_interface", "//tensorflow/stream_executor/tpu:tpu_topology_external", "@com_google_absl//absl/algorithm:container", diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index cdf32c54d86..bd20924ff23 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -64,9 +64,9 @@ limitations under the License. 
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h" #include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h" #include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h" -#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" #include "tensorflow/core/tpu/tpu_compile_interface.h" #include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 8de50acfd6c..0f183c5de98 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -56,7 +56,7 @@ cc_library( ":tpu_op_util", ":tpu_program_group_interface", ":tpu_util", - ":tpu_util_c_api_hdrs", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", ":tpu_util_hdrs", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -114,7 +114,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:refcount", "//tensorflow/core/tpu:tpu_api", - "//tensorflow/core/tpu:tpu_config_c_api", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "//tensorflow/core/tpu:tpu_configuration", "//tensorflow/core/tpu:tpu_defs", "//tensorflow/stream_executor/lib", @@ -123,19 +123,6 @@ tf_kernel_library( alwayslink = 1, ) -cc_library( - name = "tpu_compile_c_api_hdrs", - hdrs = ["tpu_compile_c_api.h"], - deps = [ - ":tpu_mesh_state_c_api_hdrs", - ":tpu_program_c_api_hdrs", - ":tpu_util_c_api_hdrs", - "//tensorflow/core/tpu:libtftpu_header", - "//tensorflow/stream_executor/tpu:c_api_decl", - ], - alwayslink = True, -) - tf_proto_library( name = "tpu_executable_info_proto", srcs = ["tpu_executable_info.proto"], @@ -261,24 +248,16 @@ cc_library( ], ) -cc_library( - name = "tpu_mesh_state_c_api_hdrs", - hdrs = ["tpu_mesh_state_c_api.h"], - deps = ["//tensorflow/core/tpu:libtftpu_header"], - alwayslink = True, -) - cc_library( name = "tpu_mesh_state_interface", srcs = [], hdrs = ["tpu_mesh_state_interface.h"], deps = [ - ":tpu_compile_c_api_hdrs", - ":tpu_mesh_state_c_api_hdrs", "//tensorflow/compiler/xla/service", "//tensorflow/core:framework", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu:tpu_api", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", ], ) @@ -310,14 +289,11 @@ cc_library( srcs = ["tpu_program_group.cc"], hdrs = ["tpu_program_group.h"], deps = [ - ":tpu_compile_c_api_hdrs", ":tpu_compile_op_common", ":tpu_compile_op_support", ":tpu_compile_proto_cc", ":tpu_executable_info_proto_cc", - ":tpu_mesh_state_c_api_hdrs", ":tpu_mesh_state_interface", - ":tpu_program_c_api_hdrs", ":tpu_program_group_interface", "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -329,6 +305,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu:tpu_api", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "//tensorflow/stream_executor/tpu:proto_helper", "//tensorflow/stream_executor/tpu:status_helper", "//tensorflow/stream_executor/tpu:tpu_platform_interface", @@ -382,11 +359,9 @@ cc_library( ":tpu_compilation_cache_key", ":tpu_compilation_metrics", # buildcleaner: keep ":tpu_compilation_metrics_hdrs", - ":tpu_compile_c_api_hdrs", ":tpu_compile_op_support", ":tpu_mesh_state_interface", ":tpu_op_consts", - 
":tpu_program_c_api_hdrs", ":tpu_program_group", ":tpu_util", ":trace_util_hdrs", @@ -398,6 +373,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -450,42 +426,18 @@ cc_library( ], ) -cc_library( - name = "tpu_util_c_api_hdrs", - hdrs = ["tpu_util_c_api.h"], - deps = [ - ":tpu_mesh_state_c_api_hdrs", - "//tensorflow/core/tpu:libtftpu_header", - "//tensorflow/stream_executor/tpu:c_api_decl", - "//tensorflow/stream_executor/tpu:proto_helper", - ], - alwayslink = True, -) - -cc_library( - name = "tpu_program_c_api_hdrs", - hdrs = ["tpu_program_c_api.h"], - deps = [ - ":tpu_util_c_api_hdrs", - "//tensorflow/core/tpu:libtftpu_header", - "//tensorflow/stream_executor/tpu:c_api_decl", - "//tensorflow/stream_executor/tpu:proto_helper", - ], - alwayslink = True, -) - cc_library( name = "tpu_op_util", srcs = ["tpu_op_util.cc"], hdrs = ["tpu_op_util.h"], deps = [ ":tpu_compilation_cache_key", - ":tpu_compile_c_api_hdrs", ":tpu_mesh_state_interface", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "@com_google_absl//absl/strings", ], ) @@ -496,7 +448,6 @@ cc_library( hdrs = ["tpu_util.h"], deps = [ ":tpu_compilation_cache_key", - ":tpu_util_c_api_hdrs", "//tensorflow/cc:ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", @@ -504,6 +455,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/tpu:tpu_api", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", tf_grpc_cc_dependency(), @@ -548,7 +500,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "//tensorflow/core/tpu:tpu_config_c_api", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "//tensorflow/stream_executor/tpu:proto_helper", ], ) @@ -665,17 +617,6 @@ cc_library( ], ) -cc_library( - name = "tpu_execute_c_api_hdrs", - hdrs = ["tpu_execute_c_api.h"], - deps = [ - ":tpu_program_c_api_hdrs", - ":tpu_util_c_api_hdrs", - "//tensorflow/core/tpu:libtftpu_header", - "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", - ], -) - cc_library( name = "tpu_compile_op_impl", srcs = ["tpu_compile_op_impl.cc"], @@ -683,11 +624,9 @@ cc_library( copts = tf_copts(), deps = [ ":tpu_compilation_cache_key", - ":tpu_compile_c_api_hdrs", ":tpu_compile_op_common", ":tpu_compile_op_support", ":tpu_compile_proto_cc", - ":tpu_mesh_state_c_api_hdrs", ":tpu_program_group", ":tpu_program_group_interface", ":tpu_util", @@ -696,6 +635,7 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "//tensorflow/stream_executor/tpu:tpu_executor", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", "@com_google_absl//absl/types:variant", diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc index 80010d70cd4..7dcc30ed182 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc +++ 
b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc @@ -25,11 +25,10 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/compiled_subgraph.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/kernels/trace_util.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h index c3f95e7e09d..655449d6291 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h @@ -34,11 +34,11 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" #include "tensorflow/core/tpu/kernels/tpu_op_consts.h" #include "tensorflow/core/tpu/kernels/tpu_program_group.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc index 5ddce57807d..88532c295cb 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc @@ -434,7 +434,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( // Check if caller has disabled compilation. Set using // internal::ScopedTpuCompileDisabler. - if (!UtilApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) { + if (!OpsApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) { const std::string error_msg = strings::StrCat( "[TpuCompilationDisabled]: Compilation cache miss, but compilation " "disabled, session_name(", diff --git a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h deleted file mode 100644 index 07bc49b2167..00000000000 --- a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_ -#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_ - -#include <stddef.h> - -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" -#include "tensorflow/core/tpu/libtftpu.h" -#include "tensorflow/stream_executor/tpu/c_api_decl.h" - -extern "C" { - -// Compiles an MLIR or TF function computation by lowering it into HLO IR and -// returns `count` TPU programs ready for execution. -// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates -// `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is -// responsible for deallocating both the `XLA_TpuProgram*[]` array and the -// `XLA_TpuProgram` object(s) using the `TpuProgram_FreeArray` and `TpuProgram_Free` -// APIs, respectively. -TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild( - TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state, - XLA_TpuProgram** tpu_programs[], size_t* count, SE_Status* status); - -struct TfTpu_CompileApiFn { - TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild); -}; - -} // extern "C" - -#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index eeb396349bb..2a16075a858 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -39,10 +39,10 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_op_util.h" #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" -#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/core/tpu/tpu_configuration.h" #include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" namespace tensorflow { namespace tpu { @@ -543,7 +543,7 @@ void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) { ctx->cancellation_manager()->get_cancellation_token(); const bool already_cancelled = !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() { - if (UtilApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) { + if (OpsApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) { return; } diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc index 270c2c53d7a..59d2aa79ace 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_program_group.h" #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h index 3f058683223..f0d731a93ef 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h @@ -23,8 +23,8 @@ limitations under the License.
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h" -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc index 271a9697f18..32741c9967c 100644 --- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc @@ -32,9 +32,9 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_op_consts.h" #include "tensorflow/core/tpu/kernels/tpu_pod_state.h" #include "tensorflow/core/tpu/tpu_api.h" -#include "tensorflow/core/tpu/tpu_config_c_api.h" #include "tensorflow/core/tpu/tpu_configuration.h" #include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/tpu/proto_helper.h" namespace tensorflow { @@ -203,12 +203,11 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) { TF_Status* status = TF_NewStatus(); auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() { TF_DeleteStatus(status); - tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn( - tpu_topology_output); + tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output); }); auto* mesh_common_state = mesh_state->mesh_common_state(); - tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn( + tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn( num_hosts, num_devices_per_host, const_cast(mapping_arg.data()), mesh_common_state, &tpu_topology_output_size, &tpu_topology_output, status); @@ -247,7 +246,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { auto tpu_host_config = ctx->input(0).scalar()(); bool is_master_worker = - tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn(); + tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn(); if (!is_master_worker) { // Reset the mesh interface if we are not the master. 
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>( @@ -283,9 +282,9 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { int32_t* device_id_output = nullptr; auto cleanup = xla::MakeCleanup([&status, &device_id_output]() { TF_DeleteStatus(status); - tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output); + tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output); }); - tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn( + tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn( tpu_host_config.size(), tpu_host_config.data(), enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size, &device_id_output, status); @@ -302,16 +301,16 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { tpu::kCompiledProtoCacheResourceName, proto_lookup)); } else { int64_t cache_size_bytes; - tpu::ConfigApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn( + tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn( &cache_size_bytes); char* server_address_output = nullptr; auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() { - tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn( + tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn( server_address_output); }); size_t server_address_output_size; - tpu::ConfigApiFn() + tpu::OpsApiFn() ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn( tpu_host_config.size(), tpu_host_config.data(), &server_address_output_size, &server_address_output, status); @@ -346,8 +345,8 @@ void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) { auto tpu_topology = ctx->input(0).scalar<tstring>()(); TF_Status* status = TF_NewStatus(); - tpu::ConfigApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(), - tpu_topology.data(), status); + tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(), + tpu_topology.data(), status); OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); TF_DeleteStatus(status); @@ -362,7 +361,7 @@ void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) { TF_Status* status = TF_NewStatus(); int32_t number_of_chips_output = 0; - tpu::ConfigApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn( + tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn( &number_of_chips_output, status); Tensor* ctx_output; diff --git a/tensorflow/core/tpu/kernels/tpu_execute_c_api.h b/tensorflow/core/tpu/kernels/tpu_execute_c_api.h deleted file mode 100644 index 81d23441ddc..00000000000 --- a/tensorflow/core/tpu/kernels/tpu_execute_c_api.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_ -#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_ - -#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" -#include "tensorflow/core/tpu/libtftpu.h" -#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" - -extern "C" { - -typedef struct XLA_DeviceAssignment { - const char* bytes; - size_t size; -} XLA_DeviceAssignment; - -TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream( - const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments, - size_t arguments_len, SE_DeviceMemoryBase* result, - SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed, - XLA_DeviceAssignment* device_assignment, SE_Stream* stream, - SE_Status* status); - -TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape( - XLA_Shape* host_shape, XLA_Shape* device_shape); -TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape); -TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape); -TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape); - -TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData( - uint32_t* runtime_input_ptr, size_t runtime_input_size, - int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape, - XLA_Shape* compile_time_shape, SE_Status* status); - -struct TfTpu_ExecuteApiFn { - TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream); - TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape); - TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize); - TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact); - TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw); - TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData); -}; - -} // extern "C" - -#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h b/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h deleted file mode 100644 index a6434d7d2fd..00000000000 --- a/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_ -#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_ - -#include "tensorflow/core/tpu/libtftpu.h" - -typedef struct XLA_TpuMeshState XLA_TpuMeshState; - -extern "C" { - -// Creates a new TPU mesh state object. -TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create(); - -// Deletes the given TPU `mesh_state` object. Once deleted the object is -// unusable. -TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); - -// Returns a pointer to an opaque mesh data structure used internally. 
-TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState( - XLA_TpuMeshState* mesh_state); - -} // extern "C" - -struct TfTpu_MeshStateApiFn { - TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create); - TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free); - TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState); -}; - -#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h b/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h index 20d1f672c65..0fed2b607ec 100644 --- a/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h +++ b/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h @@ -19,9 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" #include "tensorflow/core/tpu/tpu_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" namespace tensorflow { @@ -39,19 +38,19 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase { ~TpuMeshStateInterface() override { if (mesh_state_ != nullptr) { - MeshStateApiFn()->TpuMeshState_FreeFn(mesh_state_); + OpsApiFn()->TpuMeshState_FreeFn(mesh_state_); } } static TpuMeshStateInterface* Create() { - return new TpuMeshStateInterface(MeshStateApiFn()->TpuMeshState_CreateFn()); + return new TpuMeshStateInterface(OpsApiFn()->TpuMeshState_CreateFn()); } const XLA_TpuMeshState* data() const { return mesh_state_; } tensorflow::TpuMeshCommonState* mesh_common_state() const { return static_cast<tensorflow::TpuMeshCommonState*>( - MeshStateApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_)); + OpsApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_)); } // Returns whether we should include the device assignment as a static field @@ -63,8 +62,8 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase { // Static device assignment enables XLA to perform certain optimization when // all cores are used in the replicated computation. return metadata.num_cores_per_replica() * metadata.num_replicas() == - UtilApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_, - tpu_core_type); + OpsApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_, + tpu_core_type); } string DebugString() const override { return "TpuMeshStateInterface"; } diff --git a/tensorflow/core/tpu/kernels/tpu_op_util.cc b/tensorflow/core/tpu/kernels/tpu_op_util.cc index 0d02cac7377..01d85dabe47 100644 --- a/tensorflow/core/tpu/kernels/tpu_op_util.cc +++ b/tensorflow/core/tpu/kernels/tpu_op_util.cc @@ -17,7 +17,7 @@ limitations under the License.
#include <string> #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" namespace tensorflow { namespace tpu { @@ -77,7 +77,7 @@ std::string GuaranteedConstFingerprint( uint64_t fingerprint = 0; for (const Tensor& constant : guaranteed_constants) { fingerprint = - tpu::UtilApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn( + tpu::OpsApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn( fingerprint, constant.tensor_data().data(), constant.tensor_data().size()); } @@ -110,7 +110,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey( } } CompilationCacheKeyResult result = - tpu::UtilApiFn()->TpuCompile_CreateCompilationCacheKeyFn( + tpu::OpsApiFn()->TpuCompile_CreateCompilationCacheKeyFn( CompilationCacheKeyProperty{ config_prefix.data(), shapes_prefix.data(), @@ -125,7 +125,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey( mesh_state.data(), }); auto buffer_cleanup = gtl::MakeCleanup([result]() { - tpu::UtilApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result); + tpu::OpsApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result); }); TpuCompilationCacheKey key; key.prefix = result.key; diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.cc b/tensorflow/core/tpu/kernels/tpu_pod_state.cc index 898f02b28e9..0c52bff1109 100644 --- a/tensorflow/core/tpu/kernels/tpu_pod_state.cc +++ b/tensorflow/core/tpu/kernels/tpu_pod_state.cc @@ -74,12 +74,11 @@ Status GetServerAddressAndPort(std::string* server_address, int* serving_port) { char* server_address_output = nullptr; auto cleanup = xla::MakeCleanup([&status, &server_address_output]() { TF_DeleteStatus(status); - tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn( - server_address_output); + tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(server_address_output); }); size_t server_address_output_size; *serving_port = -1; - tpu::ConfigApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn( + tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn( &server_address_output_size, &server_address_output, serving_port, status); TF_RETURN_IF_ERROR(StatusFromTF_Status(status)); @@ -98,7 +97,7 @@ TpuPodState::~TpuPodState() { VLOG(1) << "Shutting down Compilation Cache Service."; if (cache_service_->Shutdown(20)) { if (service_port_ >= 0) { - tpu::UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_); + tpu::OpsApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_); } } else { LOG(ERROR) @@ -150,10 +149,10 @@ Status ConstructTpuPodState( char* host_config_output = nullptr; auto host_config_cleanup = xla::MakeCleanup([&host_config_output]() { - tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output); + tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output); }); size_t host_config_output_size; - tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn( + tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn( num_devices_per_host.size(), num_devices_per_host.data(), server_address.size(), server_address.data(), &host_config_output_size, &host_config_output, status); diff --git a/tensorflow/core/tpu/kernels/tpu_program_c_api.h b/tensorflow/core/tpu/kernels/tpu_program_c_api.h deleted file mode 100644 index d6e46a7c419..00000000000 --- a/tensorflow/core/tpu/kernels/tpu_program_c_api.h +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_ -#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_ - -#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" -#include "tensorflow/core/tpu/libtftpu.h" -#include "tensorflow/stream_executor/tpu/c_api_decl.h" -#include "tensorflow/stream_executor/tpu/proto_helper.h" - -typedef struct XLA_TpuProgram XLA_TpuProgram; - -// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj. -enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding }; - -struct TpuExecutableSerializedProto { - const char* bytes; - size_t size; -}; - -struct CompilerMetadataSerializedProto { - const char* bytes; - size_t size; -}; - -struct HostComputeMetadataSerializedProto { - const char* bytes; - size_t size; -}; - -extern "C" { - -// Creates a new TPU program. -TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New(); - -// Destroys the `tpu_program`. -TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program); - -// Creates an array of `XLA_TpuProgram*`. -TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count); - -// Destroys an array of `XLA_TpuProgram*`. -TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]); - -// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and -// destroyed, it is in an unusable state. -TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program, - SE_Status* status); - -// Gets TPU program size in bytes from the `tpu_program`. -TFTPU_CAPI_EXPORT int64_t -TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program); - -// Logs the summary of current memory state snapshot of the `tpu_program`. -TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary( - const XLA_TpuProgram* tpu_program); - -// Gets TPU program executable info from the `tpu_program`. -TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo( - const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info, - SE_Status* status); - -// Gets host transfer info proto. -TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo( - const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info, - SE_Status* status); - -// Gets HLO metadata proto. -TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata( - const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata, - SE_Status* status); - -// Gets may modify variables boolean value. -TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables( - const XLA_TpuProgram* tpu_program, bool* may_modify_variables); - -// Checks if TPU program has sharding. -TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( - const XLA_TpuProgram* tpu_program); - -// Gets TPU program by sharding type. Return value is valid only when the -// `status.status()` returns `OK`. 
-TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram( - XLA_TpuProgram* tpu_program, TpuProgramShardingType type); - -// Gets TPU executable proto from a `tpu_program`. -TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable( - const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable, - SE_Status* status); - -// Gets compilation metadata proto from a `tpu_program`. -TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata( - const XLA_TpuProgram* tpu_program, - CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status); - - -// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`. -TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto( - TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program, - SE_Status* status); - -struct TfTpu_TpuProgramApiFn { - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto); -}; - -} // extern "C" - -#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc index abc53cfc0eb..ad194cf6531 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.cc +++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc @@ -20,10 +20,9 @@ limitations under the License. 
#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" #include "tensorflow/core/tpu/tpu_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/tpu/proto_helper.h" #include "tensorflow/stream_executor/tpu/status_helper.h" @@ -39,7 +38,7 @@ TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo( VLOG(1) << "ConstructExecutableInfo"; TpuSerializedProto serialized_executable_info = {}; StatusHelper status; - TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn( + OpsApiFn()->TpuProgram_GetExecutableInfoFn( xla_tpu_program, &serialized_executable_info, status.c_status); TPUExecutableInfoProto executable_info; if (status.ok()) { @@ -55,7 +54,7 @@ TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo( VLOG(1) << "ConstructHostTransferInfo"; TpuSerializedProto serialized_host_transfer_info = {}; StatusHelper status; - TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn( + OpsApiFn()->TpuProgram_GetHostTransferInfoFn( xla_tpu_program, &serialized_host_transfer_info, status.c_status); TPUHostTransferInfoProto host_transfer_info; if (status.ok()) { @@ -71,7 +70,7 @@ xla::HloProto TpuProgramGroup::ConstructHloMetadata( VLOG(1) << "ConstructHloMetadata"; TpuSerializedProto serialized_hlo_metadata = {}; StatusHelper status; - TpuProgramApiFn()->TpuProgram_GetHloMetadataFn( + OpsApiFn()->TpuProgram_GetHloMetadataFn( xla_tpu_program, &serialized_hlo_metadata, status.c_status); xla::HloProto hlo_metadata; if (status.ok()) { @@ -97,8 +96,8 @@ void TpuProgramGroup::Initialize( for (size_t i = 0; i < tpu_programs_.size(); ++i) { const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i]; bool may_modify_variables; - TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn( - xla_tpu_program, &may_modify_variables); + OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program, + &may_modify_variables); may_modify_variables_array[i] = may_modify_variables; executable_infos[i] = ConstructExecutableInfo(xla_tpu_program); host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program); @@ -114,7 +113,7 @@ void TpuProgramGroup::Initialize( bool TpuProgramGroup::has_sharding_program() const { for (const XLA_TpuProgram* tpu_program : tpu_programs_) { - if (!TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_program)) { + if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) { return false; } } @@ -126,7 +125,7 @@ size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); } int64_t TpuProgramGroup::program_size() const { int64_t total_size = 0; for (const XLA_TpuProgram* tpu_program : tpu_programs_) { - total_size += TpuProgramApiFn()->TpuProgram_GetProgramSizeFn(tpu_program); + total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program); } return total_size; } @@ -134,8 +133,7 @@ int64_t TpuProgramGroup::program_size() const { bool TpuProgramGroup::LogProgramMemorySummary() { bool success = true; for (const XLA_TpuProgram* tpu_program : tpu_programs_) { - success &= - TpuProgramApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program); + success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program); } return success; } @@ -143,8 +141,7 @@ bool TpuProgramGroup::LogProgramMemorySummary() { void TpuProgramGroup::UnloadAndDestroyPrograms() 
{ for (XLA_TpuProgram* tpu_program : tpu_programs_) { StatusHelper status; - TpuProgramApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, - status.c_status); + OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status); auto s = status.status(); if (!s.ok()) { LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString(); @@ -208,8 +205,8 @@ bool TpuProgramGroup::may_modify_variables(int index) const { CHECK_GE(index, 0); CHECK_LT(index, tpu_programs_.size()); bool may_modify_variables; - TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index], - &may_modify_variables); + OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index], + &may_modify_variables); return may_modify_variables; } @@ -258,9 +255,9 @@ Status TpuProgramGroup::CompileAndBuild( size_t count = 0; XLA_TpuProgram** xla_tpu_programs = nullptr; StatusHelper status; - CompileApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request, - mesh_state, &xla_tpu_programs, - &count, status.c_status); + OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request, + mesh_state, &xla_tpu_programs, + &count, status.c_status); if (!status.ok()) { VLOG(1) << "Run CompileAndBuild failed."; return status.status(); } @@ -275,7 +272,7 @@ Status TpuProgramGroup::CompileAndBuild( tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface); tpu_program_group->Initialize( absl::MakeConstSpan(&xla_tpu_programs[0], count)); - TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs); + OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs); return status.status(); } @@ -284,8 +281,8 @@ std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs( std::vector<XLA_TpuProgram*> tpu_programs; tpu_programs.reserve(tpu_programs_.size()); for (size_t i = 0; i < tpu_programs_.size(); ++i) { - if (TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) { - tpu_programs.push_back(TpuProgramApiFn()->TpuProgram_GetTpuProgramFn( + if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) { + tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn( tpu_programs_[i], sharding_type)); CHECK_NE(tpu_programs[i], nullptr); } } @@ -300,11 +297,11 @@ Status TpuProgramGroup::DeserializeFromRpcResponseProtos( for (size_t i = 0; i < rpc_response_protos.size(); ++i) { StatusHelper status; - auto* xla_tpu_program = TpuProgramApiFn()->TpuProgram_NewFn(); - TpuProgramApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn( + auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn(); + OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn( rpc_response_protos[i], xla_tpu_program, status.c_status); if (!status.status().ok()) { - TpuProgramApiFn()->TpuProgram_FreeFn(xla_tpu_program); + OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program); return status.status(); } tpu_programs[i] = xla_tpu_program; @@ -319,8 +316,8 @@ Status TpuProgramGroup::SerializeExecutable( CHECK_GE(index, 0); CHECK_LT(index, tpu_programs_.size()); StatusHelper status; - TpuProgramApiFn()->TpuProgram_SerializeTpuExecutableFn( - tpu_programs_[index], executable, status.c_status); + OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index], + executable, status.c_status); return status.status(); } @@ -329,7 +326,7 @@ Status TpuProgramGroup::SerializeCompilerMetadata( CHECK_GE(index, 0); CHECK_LT(index, tpu_programs_.size()); StatusHelper status; - TpuProgramApiFn()->TpuProgram_SerializeCompilerMetadataFn( + OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn( tpu_programs_[index], compiler_metadata, status.c_status); return status.status(); } diff --git
a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h index 3ed1623e9e6..5812976d0d3 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.h +++ b/tensorflow/core/tpu/kernels/tpu_program_group.h @@ -27,10 +27,9 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" -#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" namespace tensorflow { diff --git a/tensorflow/core/tpu/kernels/tpu_util_c_api.h b/tensorflow/core/tpu/kernels/tpu_util_c_api.h deleted file mode 100644 index 04b65e24e54..00000000000 --- a/tensorflow/core/tpu/kernels/tpu_util_c_api.h +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_ -#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_ - -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" -#include "tensorflow/core/tpu/libtftpu.h" -#include "tensorflow/stream_executor/tpu/proto_helper.h" - -// Property for creating compilation cache key. -struct CompilationCacheKeyProperty { - const char* config_prefix; - const char* shapes_prefix; - const char* function_name; - const char* mlir_module; - const int32_t* device_ids; - size_t device_ids_size; - int32_t guaranteed_constants_size; - uint64_t function_library_fingerprint; - int32_t num_cores_per_replica; - int32_t num_replicas; - const XLA_TpuMeshState* mesh_state; -}; - -// Compilation cache key result returning both the key and a more verbose debug -// version. -struct CompilationCacheKeyResult { - const char* key; - const char* debug_string; -}; - -extern "C" { - -// Checks whether TPU compilation is enabled. -TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled(); - -// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will exit -// when cancellation is requested for an XLA compile op. Some tests require this -// behavior to be disabled, and we test for this condition with the following -// flag function. -TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation(); - -// Returns the number of available TPU cores. -TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount( - const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type); - -// Recycles an unused service port. -TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port); - -// Creates a unique compilation cache `key` used for `put` and `get` operations.
-// Returned buffers are heap-allocated; ownership passes to the caller. -TFTPU_CAPI_EXPORT CompilationCacheKeyResult -TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property); - -// Destroys the CompilationCacheKeyResult returned by calling the -// `TpuCompile_CreateCompilationCacheKey` API. -TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey( - CompilationCacheKeyResult result); - -// Creates a guaranteed const fingerprint. A guaranteed const is normally used in -// TPU inference to avoid re-copying unchanged variables onto the TPU device. -// It promises that the value is identical for every execution in the same session -// even if the actual value changes in later executions. -TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint( - uint64_t fingerprint, const char* data, size_t size); - -} // extern "C" - -struct TfTpu_UtilApiFn { - TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled); - TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); - TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); - TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); - TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); - TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); - TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint); -}; - -#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_ diff --git a/tensorflow/core/tpu/tpu_api.cc b/tensorflow/core/tpu/tpu_api.cc index 17520ea6ea4..690e2049652 100644 --- a/tensorflow/core/tpu/tpu_api.cc +++ b/tensorflow/core/tpu/tpu_api.cc @@ -23,39 +23,9 @@ TfTpu_BaseFn* InitializeApiFn() { return &base_fn; } -TfTpu_ConfigApiFn* ConfigApiFn() { - static TfTpu_ConfigApiFn config_api_fn; - return &config_api_fn; -} - -TfTpu_MeshStateApiFn* MeshStateApiFn() { - static TfTpu_MeshStateApiFn mesh_state_api_fn; - return &mesh_state_api_fn; -} - -TfTpu_CompileApiFn* CompileApiFn() { - static TfTpu_CompileApiFn compile_api_fn; - return &compile_api_fn; -} - -TfTpu_ExecuteApiFn* ExecuteApiFn() { - static TfTpu_ExecuteApiFn execute_api_fn; - return &execute_api_fn; -} - -TfTpu_TpuProgramApiFn* TpuProgramApiFn() { - static TfTpu_TpuProgramApiFn tpu_program_api_fn; - return &tpu_program_api_fn; -} - -TfTpu_NodeContextApiFn* NodeContextApiFn() { - static TfTpu_NodeContextApiFn node_context_api_fn; - return &node_context_api_fn; -} - -TfTpu_UtilApiFn* UtilApiFn() { - static TfTpu_UtilApiFn util_api_fn; - return &util_api_fn; +TfTpu_OpsApiFn* OpsApiFn() { + static TfTpu_OpsApiFn ops_api_fn; + return &ops_api_fn; } } // namespace tpu diff --git a/tensorflow/core/tpu/tpu_api.h b/tensorflow/core/tpu/tpu_api.h index a9f7bccfdb4..b880f4ed9cf 100644 --- a/tensorflow/core/tpu/tpu_api.h +++ b/tensorflow/core/tpu/tpu_api.h @@ -16,33 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_TPU_API_H_ #define TENSORFLOW_CORE_TPU_TPU_API_H_ -#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" #include "tensorflow/core/tpu/libtftpu.h" -#include "tensorflow/core/tpu/tpu_config_c_api.h" #include "tensorflow/core/tpu/tpu_executor_api.h" -#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" namespace tensorflow { namespace tpu { TfTpu_BaseFn* InitializeApiFn(); -TfTpu_ConfigApiFn* ConfigApiFn(); - -TfTpu_MeshStateApiFn* MeshStateApiFn(); - -TfTpu_CompileApiFn* CompileApiFn(); - -TfTpu_ExecuteApiFn* ExecuteApiFn(); - -TfTpu_TpuProgramApiFn* TpuProgramApiFn(); - -TfTpu_NodeContextApiFn* NodeContextApiFn(); - -TfTpu_UtilApiFn* UtilApiFn(); +TfTpu_OpsApiFn* OpsApiFn(); } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.h b/tensorflow/core/tpu/tpu_api_dlsym_initializer.h index 1126e132264..ffb5ffb33a9 100644 --- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.h +++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.h @@ -17,14 +17,9 @@ limitations under the License. #define TENSORFLOW_CORE_TPU_TPU_API_DLSYM_INITIALIZER_H_ #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" -#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" #include "tensorflow/core/tpu/libtftpu.h" -#include "tensorflow/core/tpu/tpu_config_c_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" -#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h" // LINT.IfChange namespace tensorflow { diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h deleted file mode 100644 index de4b2e25570..00000000000 --- a/tensorflow/core/tpu/tpu_config_c_api.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_ -#define TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_ - -#include <cstddef> -#include <cstdint> - -#include "tensorflow/c/tf_status.h" -#include "tensorflow/core/tpu/libtftpu.h" - -typedef struct TpuSerializedProto TpuSerializedProto; - -namespace tensorflow { -class TpuMeshCommonState; -} // namespace tensorflow - -extern "C" { - -TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork( - const size_t num_cores_per_host_size, const int32_t* num_cores_per_host, - size_t server_address_size, const char* server_address, - size_t* host_config_output_size, char** host_config_output, - TF_Status* status); - -TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork( - const size_t num_hosts, const size_t num_cores_per_host, - const int32_t** host_ordinal_to_global_core_id_map, - tensorflow::TpuMeshCommonState* tpu_mesh_common_state, - size_t* tpu_topology_output_size, char** tpu_topology_output, - TF_Status* status); - -TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork( - const size_t tpu_host_config_size, const char* tpu_host_config, - const bool enable_whole_mesh_compilations, bool is_master_worker, - size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status); - -TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork( - const size_t tpu_topology_size, const char* tpu_topology, - TF_Status* status); - -TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork( - int32_t* number_of_chips_output, TF_Status* status); - -TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output); -TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output); - -TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState(); - -TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus, TF_Status* status); -TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit, TF_Status* status); - -TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes( - int64_t* cache_size_in_bytes); -TFTPU_CAPI_EXPORT -void TpuConfigurationApi_CompilationCacheServerAddressFromConfig( - size_t tpu_host_config_size, const char* tpu_host_config, - size_t* server_address_output_size, char** server_address_output, - TF_Status* status); -TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort( - size_t* server_address_output_size, char** server_address_output, - int* port_output, TF_Status* status); -} - -struct TfTpu_ConfigApiFn { - TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork); - TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork); - TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork); - TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork); - TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork); - TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray); - TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array); - TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState); - TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost); - TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit); - TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); - TFTPU_ADD_FN_IN_STRUCT( TpuConfigurationApi_CompilationCacheServerAddressFromConfig); - TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort); -}; - -#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_ diff --git a/tensorflow/core/tpu/tpu_execute.cc b/tensorflow/core/tpu/tpu_execute.cc index
29a05c0d538..71455936d60 100644 --- a/tensorflow/core/tpu/tpu_execute.cc +++ b/tensorflow/core/tpu/tpu_execute.cc @@ -107,7 +107,7 @@ xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) { XLA_Shape c_host_shape; XLA_Shape c_device_shape; ApiConverter::ToC(host_shape, &c_host_shape); - tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn( + tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn( &c_host_shape, &c_device_shape); xla::Shape device_shape = ApiConverter::FromC(&c_device_shape); ApiConverter::Free(&c_host_shape); @@ -119,8 +119,7 @@ int64 ShapeSizeCompact(const xla::Shape& shape) { XLA_Shape c_shape; ApiConverter::ToC(shape, &c_shape); int64 size = - tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactFn( - &c_shape); + tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape); ApiConverter::Free(&c_shape); return size; } @@ -129,7 +128,7 @@ int64 ShapeSizeCompactRaw(const xla::Shape& shape) { XLA_Shape c_shape; ApiConverter::ToC(shape, &c_shape); int64 size = - tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactRawFn( + tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactRawFn( &c_shape); ApiConverter::Free(&c_shape); return size; @@ -241,11 +240,10 @@ xla::Status UpdateDynamicInputs( ApiConverter::ToC(runtime_shape, &c_runtime_shape); ApiConverter::ToC(compile_time_shape, &c_compile_time_shape); StatusHelper status; - tensorflow::tpu::ExecuteApiFn() - ->TpuExecute_RuntimeInputToPaddedDataFn( - raw_input_runtime->data(), raw_input_runtime->size(), - padded_data->data(), padded_data->size(), &c_runtime_shape, - &c_compile_time_shape, status.c_status); + tensorflow::tpu::OpsApiFn()->TpuExecute_RuntimeInputToPaddedDataFn( + raw_input_runtime->data(), raw_input_runtime->size(), + padded_data->data(), padded_data->size(), &c_runtime_shape, + &c_compile_time_shape, status.c_status); ApiConverter::Free(&c_runtime_shape); ApiConverter::Free(&c_compile_time_shape); return status.status(); diff --git a/tensorflow/core/tpu/tpu_execute.h b/tensorflow/core/tpu/tpu_execute.h index e2142ad7a7a..fc247eb2e7d 100644 --- a/tensorflow/core/tpu/tpu_execute.h +++ b/tensorflow/core/tpu/tpu_execute.h @@ -25,8 +25,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/tpu/tpu_node_context.h" diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc index f824c9202e5..dc9714a8918 100644 --- a/tensorflow/core/tpu/tpu_library_init_fns.inc +++ b/tensorflow/core/tpu/tpu_library_init_fns.inc @@ -6,118 +6,76 @@ namespace { -tensorflow::Status SetTpuConfigStructFns(void* library_handle) { - auto* config_fn = tensorflow::tpu::ConfigApiFn(); +tensorflow::Status SetTpuOpsStructFns(void* library_handle) { + auto* ops_api_fn = tensorflow::tpu::OpsApiFn(); - TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork); - TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork); - TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork); - TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork); - TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork); - TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray); - TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array); - TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState); - TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpusPerHost); - TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpuMemoryLimit); - TFTPU_SET_FN(config_fn, + TFTPU_SET_FN(ops_api_fn, ConfigureDistributedTpuOp_DoWork); + TFTPU_SET_FN(ops_api_fn, WaitForDistributedTpuOp_DoWork); + TFTPU_SET_FN(ops_api_fn, InitializeHostForDistributedTpuOp_DoWork); + TFTPU_SET_FN(ops_api_fn, SetGlobalTPUArrayOp_DoWork); + TFTPU_SET_FN(ops_api_fn, DisconnectDistributedTpuChipsOp_DoWork); + TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeCharArray); + TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeInt32Array); + TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_HasTPUPodState); + TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpusPerHost); + TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpuMemoryLimit); + TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); - TFTPU_SET_FN(config_fn, + TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_CompilationCacheServerAddressFromConfig); - TFTPU_SET_FN(config_fn, TpuConfigurationApi_GetServerAddressAndPort); + TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_GetServerAddressAndPort); - return tensorflow::Status::OK(); -} + TFTPU_SET_FN(ops_api_fn, TpuMeshState_Create); + TFTPU_SET_FN(ops_api_fn, TpuMeshState_Free); + TFTPU_SET_FN(ops_api_fn, TpuMeshState_MeshCommonState); -tensorflow::Status SetTpuMeshStateStructFns(void* library_handle) { - auto* mesh_state_fn = tensorflow::tpu::MeshStateApiFn(); + TFTPU_SET_FN(ops_api_fn, TpuCompile_CompileAndBuild); - TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Create); - TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Free); - TFTPU_SET_FN(mesh_state_fn, TpuMeshState_MeshCommonState); + TFTPU_SET_FN(ops_api_fn, TpuExecutable_LoadProgramAndEnqueueToStream); + TFTPU_SET_FN(ops_api_fn, HardwareLayout_HostShapeToDeviceShape); + TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSize); + TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompact); + TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompactRaw); + TFTPU_SET_FN(ops_api_fn, TpuExecute_RuntimeInputToPaddedData); - return tensorflow::Status::OK(); -} - -tensorflow::Status 
SetCompileStructFn(void* library_handle) { - auto* compile_fn = tensorflow::tpu::CompileApiFn(); - - TFTPU_SET_FN(compile_fn, TpuCompile_CompileAndBuild); - - return tensorflow::Status::OK(); -} - -tensorflow::Status SetExecuteStructFn(void* library_handle) { - auto* execute_fn = tensorflow::tpu::ExecuteApiFn(); - - TFTPU_SET_FN(execute_fn, TpuExecutable_LoadProgramAndEnqueueToStream); - TFTPU_SET_FN(execute_fn, HardwareLayout_HostShapeToDeviceShape); - TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSize); - TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompact); - TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompactRaw); - TFTPU_SET_FN(execute_fn, TpuExecute_RuntimeInputToPaddedData); - - return tensorflow::Status::OK(); -} - -tensorflow::Status SetTpuProgramStructFn(void* library_handle) { - auto* tpu_program_fn = tensorflow::tpu::TpuProgramApiFn(); - - TFTPU_SET_FN(tpu_program_fn, TpuProgram_New); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_Free); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_NewArray); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_FreeArray); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_UnloadAndDestroy); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetProgramSize); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_LogProgramMemorySummary); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetExecutableInfo); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHostTransferInfo); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeTpuExecutable); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeCompilerMetadata); - TFTPU_SET_FN(tpu_program_fn, + TFTPU_SET_FN(ops_api_fn, TpuProgram_New); + TFTPU_SET_FN(ops_api_fn, TpuProgram_Free); + TFTPU_SET_FN(ops_api_fn, TpuProgram_NewArray); + TFTPU_SET_FN(ops_api_fn, TpuProgram_FreeArray); + TFTPU_SET_FN(ops_api_fn, TpuProgram_UnloadAndDestroy); + TFTPU_SET_FN(ops_api_fn, TpuProgram_GetProgramSize); + TFTPU_SET_FN(ops_api_fn, TpuProgram_LogProgramMemorySummary); + TFTPU_SET_FN(ops_api_fn, TpuProgram_GetExecutableInfo); + TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHostTransferInfo); + TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHloMetadata); + TFTPU_SET_FN(ops_api_fn, TpuProgram_GetMayModifyVariables); + TFTPU_SET_FN(ops_api_fn, TpuProgram_HasSharding); + TFTPU_SET_FN(ops_api_fn, TpuProgram_GetTpuProgram); + TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeTpuExecutable); + TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeCompilerMetadata); + TFTPU_SET_FN(ops_api_fn, TpuProgram_DeserializeFromGetTpuProgramResponseProto); - return tensorflow::Status::OK(); -} + TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Create); + TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Free); + TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Initialize); + TFTPU_SET_FN(ops_api_fn, TpuNodeContext_StopChipHeartbeats); + TFTPU_SET_FN(ops_api_fn, TpuNodeContext_CloseTpuHost); -tensorflow::Status SetTpuNodeContextStructFns(void* library_handle) { - auto* node_context_fn = tensorflow::tpu::NodeContextApiFn(); - - TFTPU_SET_FN(node_context_fn, TpuNodeContext_Create); - TFTPU_SET_FN(node_context_fn, TpuNodeContext_Free); - TFTPU_SET_FN(node_context_fn, TpuNodeContext_Initialize); - TFTPU_SET_FN(node_context_fn, TpuNodeContext_StopChipHeartbeats); - TFTPU_SET_FN(node_context_fn, TpuNodeContext_CloseTpuHost); - - return tensorflow::Status::OK(); -} - -tensorflow::Status SetTpuUtilStructFns(void* 
library_handle) { - auto* util_fn = tensorflow::tpu::UtilApiFn(); - - TFTPU_SET_FN(util_fn, TpuTopology_AvailableCoreCount); - TFTPU_SET_FN(util_fn, TpuNetUtil_RecycleUnusedPort); - TFTPU_SET_FN(util_fn, TpuCompile_IsTpuCompilationEnabled); - TFTPU_SET_FN(util_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation); - TFTPU_SET_FN(util_fn, TpuCompile_CreateCompilationCacheKey); - TFTPU_SET_FN(util_fn, TpuCompile_DestroyCompilationCacheKey); - TFTPU_SET_FN(util_fn, TpuCompile_CreateGuaranteedConstFingerprint); + TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoreCount); + TFTPU_SET_FN(ops_api_fn, TpuNetUtil_RecycleUnusedPort); + TFTPU_SET_FN(ops_api_fn, TpuCompile_IsTpuCompilationEnabled); + TFTPU_SET_FN(ops_api_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation); + TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey); + TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey); + TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint); return tensorflow::Status::OK(); } tensorflow::Status InitializeTpuStructFns(void* library_handle) { - TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle)); - TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle)); - TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle)); - TF_RETURN_IF_ERROR(SetExecuteStructFn(library_handle)); - TF_RETURN_IF_ERROR(SetTpuProgramStructFn(library_handle)); + TF_RETURN_IF_ERROR(SetTpuOpsStructFns(library_handle)); TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle)); - TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle)); - TF_RETURN_IF_ERROR(SetTpuUtilStructFns(library_handle)); return tensorflow::Status::OK(); } diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h new file mode 100644 index 00000000000..1682662a4dc --- /dev/null +++ b/tensorflow/core/tpu/tpu_ops_c_api.h @@ -0,0 +1,342 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ + +#include <stddef.h> + +#include "tensorflow/core/tpu/libtftpu.h" +#include "tensorflow/stream_executor/tpu/c_api_decl.h" +#include "tensorflow/stream_executor/tpu/proto_helper.h" + +typedef struct TpuSerializedProto TpuSerializedProto; + +namespace tensorflow { +class TpuMeshCommonState; +} // namespace tensorflow + +extern "C" { + +typedef struct XLA_TpuProgram XLA_TpuProgram; + +// Enum for choosing the sharding/unsharding program from an `XLA_TpuProgram` +// object.
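+// As an illustrative sketch only (not part of this header; `prog` stands in +// for a compiled program handle obtained elsewhere), a caller might select the +// sharding variant through the `TpuProgram_GetTpuProgram` API declared later +// in this header: +// +//   XLA_TpuProgram* sharding_prog = TpuProgram_GetTpuProgram(prog, kSharding);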
+enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding }; + +struct TpuExecutableSerializedProto { + const char* bytes; + size_t size; +}; + +struct CompilerMetadataSerializedProto { + const char* bytes; + size_t size; +}; + +struct HostComputeMetadataSerializedProto { + const char* bytes; + size_t size; +}; + +typedef struct XLA_TpuMeshState XLA_TpuMeshState; + +typedef struct XLA_DeviceAssignment { + const char* bytes; + size_t size; +} XLA_DeviceAssignment; + +// Properties for creating a compilation cache key. +struct CompilationCacheKeyProperty { + const char* config_prefix; + const char* shapes_prefix; + const char* function_name; + const char* mlir_module; + const int32_t* device_ids; + size_t device_ids_size; + int32_t guaranteed_constants_size; + uint64_t function_library_fingerprint; + int32_t num_cores_per_replica; + int32_t num_replicas; + const XLA_TpuMeshState* mesh_state; +}; + +// Compilation cache key result returning both the key and a more verbose debug +// version. +struct CompilationCacheKeyResult { + const char* key; + const char* debug_string; +}; + +typedef struct XLA_TpuNodeContext XLA_TpuNodeContext; + +// Compiles an MLIR or TF function computation by lowering it into HLO IR and +// returns `count` TPU programs ready for execution. +// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates +// `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is +// responsible for deallocating both the `XLA_TpuProgram*[]` array and the +// `XLA_TpuProgram` object(s), using the `TpuProgram_FreeArray` and +// `TpuProgram_Free` APIs respectively. +TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild( + TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state, + XLA_TpuProgram** tpu_programs[], size_t* count, SE_Status* status); + +// Creates a new TPU mesh state object. +TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create(); + +// Deletes the given TPU `mesh_state` object. Once deleted, the object is +// unusable. +TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); + +// Returns a pointer to an opaque mesh data structure used internally.
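+// +// Taken together, an illustrative mesh-state lifecycle sketch (editorial +// example; `request` and `status` are hypothetical stand-ins, and error +// handling is elided): +// +//   XLA_TpuMeshState* mesh_state = TpuMeshState_Create(); +//   XLA_TpuProgram** tpu_programs = nullptr; +//   size_t count = 0; +//   TpuCompile_CompileAndBuild(request, mesh_state, &tpu_programs, &count, +//                              status); +//   ...  // Run the programs, free them, then free the mesh state. +//   TpuMeshState_Free(mesh_state); +//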
+TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState( + XLA_TpuMeshState* mesh_state); + +TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream( + const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments, + size_t arguments_len, SE_DeviceMemoryBase* result, + SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed, + XLA_DeviceAssignment* device_assignment, SE_Stream* stream, + SE_Status* status); + +TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape( + XLA_Shape* host_shape, XLA_Shape* device_shape); +TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape); +TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape); +TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape); + +TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData( + uint32_t* runtime_input_ptr, size_t runtime_input_size, + int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape, + XLA_Shape* compile_time_shape, SE_Status* status); + +TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork( + const size_t num_cores_per_host_size, const int32_t* num_cores_per_host, + size_t server_address_size, const char* server_address, + size_t* host_config_output_size, char** host_config_output, + TF_Status* status); + +TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork( + const size_t num_hosts, const size_t num_cores_per_host, + const int32_t** host_ordinal_to_global_core_id_map, + tensorflow::TpuMeshCommonState* tpu_mesh_common_state, + size_t* tpu_topology_output_size, char** tpu_topology_output, + TF_Status* status); + +TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork( + const size_t tpu_host_config_size, const char* tpu_host_config, + const bool enable_whole_mesh_compilations, bool is_master_worker, + size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status); + +TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork( + const size_t tpu_topology_size, const char* tpu_topology, + TF_Status* status); + +TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork( + int32_t* number_of_chips_output, TF_Status* status); + +TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output); +TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output); + +TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState(); + +TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus, + TF_Status* status); +TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit, + TF_Status* status); +TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes( + int64_t* cache_size_in_bytes); +TFTPU_CAPI_EXPORT +void TpuConfigurationApi_CompilationCacheServerAddressFromConfig( + size_t tpu_host_config_size, const char* tpu_host_config, + size_t* server_address_output_size, char** server_address_output, + TF_Status* status); +TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort( + size_t* server_address_output_size, char** server_address_output, + int* port_output, TF_Status* status); + +// Creates a new TPU program. +TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New(); + +// Destroys the `tpu_program`. +TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program); + +// Creates an array of `XLA_TpuProgram*`. +TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count); + +// Destroys an array of `XLA_TpuProgram*`. 
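+// Per the `TpuCompile_CompileAndBuild` contract above, a caller typically +// frees each program and then the array itself (illustrative sketch, reusing +// the `tpu_programs`/`count` outputs from that call): +// +//   for (size_t i = 0; i < count; ++i) { +//     TpuProgram_Free(tpu_programs[i]); +//   } +//   TpuProgram_FreeArray(tpu_programs);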
+TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]); + +// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and +// destroyed, it is in an unusable state. +TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program, + SE_Status* status); + +// Gets TPU program size in bytes from the `tpu_program`. +TFTPU_CAPI_EXPORT int64_t +TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program); + +// Logs a summary of the current memory state snapshot of the `tpu_program`. +TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary( + const XLA_TpuProgram* tpu_program); + +// Gets TPU program executable info from the `tpu_program`. +TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo( + const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info, + SE_Status* status); + +// Gets host transfer info proto. +TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo( + const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info, + SE_Status* status); + +// Gets HLO metadata proto. +TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata( + const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata, + SE_Status* status); + +// Gets the boolean value indicating whether the TPU program may modify +// variables. +TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables( + const XLA_TpuProgram* tpu_program, bool* may_modify_variables); + +// Checks whether the TPU program has sharding. +TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( + const XLA_TpuProgram* tpu_program); + +// Gets TPU program by sharding type. The return value is valid only when the +// `status.status()` returns `OK`. +TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram( + XLA_TpuProgram* tpu_program, TpuProgramShardingType type); + +// Gets TPU executable proto from a `tpu_program`. +TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable( + const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable, + SE_Status* status); + +// Gets compilation metadata proto from a `tpu_program`. +TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata( + const XLA_TpuProgram* tpu_program, + CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status); + +// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`. +TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto( + TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program, + SE_Status* status); + +// Checks whether TPU compilation is enabled. +TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled(); + +// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will +// exit when cancellation is requested for an XLA compile op. Some tests require +// this behavior to be disabled, and we test for this condition with the +// following flag function. +TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation(); + +// Returns the number of available TPU cores. +TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount( + const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type); + +// Recycles an unused service port. +TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port); + +// Creates a unique compilation cache `key` used for `put` and `get` operations. +// Returned buffers are heap-allocated and must be released by the caller with +// `TpuCompile_DestroyCompilationCacheKey`.
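+// A create/use/destroy pairing sketch (illustrative; `property` is a +// populated `CompilationCacheKeyProperty`): +// +//   CompilationCacheKeyResult result = +//       TpuCompile_CreateCompilationCacheKey(property); +//   // ... use result.key, and result.debug_string for logging ... +//   TpuCompile_DestroyCompilationCacheKey(result);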
+TFTPU_CAPI_EXPORT CompilationCacheKeyResult +TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property); + +// Destroys the CompilationCacheKeyResult returned by calling the +// `TpuCompile_CreateCompilationCacheKey` API. +TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey( + CompilationCacheKeyResult result); + +// Creates a guaranteed const fingerprint. Guaranteed const is normally used in +// TPU inference to avoid re-copying unchanged variables onto the TPU device. +// It promises the value is identical for every execution in the same session +// even if the actual value changes in later executions. +TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint( + uint64_t fingerprint, const char* data, size_t size); + +XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal, + SE_Status* status); +void TpuNodeContext_Free(XLA_TpuNodeContext* node_context); + +void TpuNodeContext_StopChipHeartbeats(SE_Status* status); + +void TpuNodeContext_CloseTpuHost(SE_Status* status); + +void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status); + +struct TfTpu_OpsApiFn { + TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild); + + TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create); + TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free); + TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState); + + TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream); + TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape); + TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize); + TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact); + TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw); + TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData); + + TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork); + TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork); + TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork); + TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork); + TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); + TFTPU_ADD_FN_IN_STRUCT( + TpuConfigurationApi_CompilationCacheServerAddressFromConfig); + TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort); + + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto); + + TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled); +
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); + TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); + TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); + TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); + TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); + TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint); + + TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create); + TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free); + TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats); + TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost); + TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize); +}; + +} // extern "C" + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD index 98d59726b60..0ce409eb99a 100644 --- a/tensorflow/stream_executor/tpu/BUILD +++ b/tensorflow/stream_executor/tpu/BUILD @@ -195,7 +195,6 @@ cc_library( deps = [ ":status_helper", ":tpu_executor_c_api_hdrs", - ":tpu_node_context_c_api_hdrs", ":tpu_platform_interface", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", @@ -204,6 +203,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/tpu:tpu_api", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/memory", @@ -293,9 +293,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core/tpu:tpu_api", - "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs", - "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs", - "//tensorflow/core/tpu/kernels:tpu_program_c_api_hdrs", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", "//tensorflow/stream_executor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", diff --git a/tensorflow/stream_executor/tpu/tpu_executable.cc b/tensorflow/stream_executor/tpu/tpu_executable.cc index 3f7d88392e5..bdae4b9dce5 100644 --- a/tensorflow/stream_executor/tpu/tpu_executable.cc +++ b/tensorflow/stream_executor/tpu/tpu_executable.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/casts.h" -#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h" #include "tensorflow/core/tpu/tpu_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/tpu/c_api_conversions.h" #include "tensorflow/stream_executor/tpu/proto_helper.h" #include "tensorflow/stream_executor/tpu/status_helper.h" @@ -79,11 +79,10 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream( run_options.run_options().stream()->implementation()); StatusHelper status; - tensorflow::tpu::ExecuteApiFn() - ->TpuExecutable_LoadProgramAndEnqueueToStreamFn( - core_program_, arguments_bases, arguments.size(), &result_base, - (cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr), - rng_seed, &c_dev_assign, stream, status.c_status); + tensorflow::tpu::OpsApiFn()->TpuExecutable_LoadProgramAndEnqueueToStreamFn( + core_program_, arguments_bases, arguments.size(), &result_base, + (cross_program_prefetch_addr.has_value() ? 
&prefetch_base : nullptr), + rng_seed, &c_dev_assign, stream, status.c_status); if (dev_assign != nullptr) { stream_executor::tpu::SerializedProto_Free(dev_assign_serialized); @@ -96,7 +95,7 @@ Shape TpuExecutable::HostShapeToDeviceShape(const Shape& host_shape) { XLA_Shape c_host_shape; XLA_Shape c_device_shape; ApiConverter::ToC(host_shape, &c_host_shape); - tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn( + tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn( &c_host_shape, &c_device_shape); Shape device_shape = ApiConverter::FromC(&c_device_shape); ApiConverter::Free(&c_host_shape); @@ -108,7 +107,7 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) { XLA_Shape c_shape; ApiConverter::ToC(shape, &c_shape); int64 size = - tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeFn(&c_shape); + tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeFn(&c_shape); ApiConverter::Free(&c_shape); return size; } diff --git a/tensorflow/stream_executor/tpu/tpu_executable.h b/tensorflow/stream_executor/tpu/tpu_executable.h index d2c3200c93d..0785d66b83a 100644 --- a/tensorflow/stream_executor/tpu/tpu_executable.h +++ b/tensorflow/stream_executor/tpu/tpu_executable.h @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/tpu/tpu_executable_interface.h" diff --git a/tensorflow/stream_executor/tpu/tpu_node_context.cc b/tensorflow/stream_executor/tpu/tpu_node_context.cc index b5597e2f88f..13447a74d40 100644 --- a/tensorflow/stream_executor/tpu/tpu_node_context.cc +++ b/tensorflow/stream_executor/tpu/tpu_node_context.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/stream_executor/tpu/tpu_node_context.h" #include "tensorflow/core/tpu/tpu_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" -#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h" namespace tensorflow { namespace tpu { @@ -30,40 +30,38 @@ StatusOr> TpuNodeContext::Create( int device_ordinal) { StatusHelper status; XLA_TpuNodeContext* node_context = - tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(device_ordinal, - status.c_status); + tpu::OpsApiFn()->TpuNodeContext_CreateFn(device_ordinal, status.c_status); if (!status.status().ok()) { // TpuNodeContext_CreateFn allocates a new XLA_TpuNodeContext regardless of // status. It needs to be freed if it's not given to a TpuNodeContext below. 
- tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context); + tpu::OpsApiFn()->TpuNodeContext_FreeFn(node_context); return status.status(); } return std::make_unique(device_ordinal, node_context); } TpuNodeContext::~TpuNodeContext() { - tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context_); + tpu::OpsApiFn()->TpuNodeContext_FreeFn(node_context_); } /* static */ Status TpuNodeContext::StopChipHeartbeats() { StatusHelper status; - tpu::NodeContextApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status); + tpu::OpsApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status); return status.status(); } /* static */ Status TpuNodeContext::CloseTpuHost() { StatusHelper status; - tpu::NodeContextApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status); + tpu::OpsApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status); return status.status(); } /* static */ Status TpuNodeContext::Initialize(int device_ordinal) { StatusHelper status; - tpu::NodeContextApiFn()->TpuNodeContext_InitializeFn(device_ordinal, - status.c_status); + tpu::OpsApiFn()->TpuNodeContext_InitializeFn(device_ordinal, status.c_status); return status.status(); } diff --git a/tensorflow/stream_executor/tpu/tpu_node_context.h b/tensorflow/stream_executor/tpu/tpu_node_context.h index 27cf32f854f..48b3c25bf10 100644 --- a/tensorflow/stream_executor/tpu/tpu_node_context.h +++ b/tensorflow/stream_executor/tpu/tpu_node_context.h @@ -24,11 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" #include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/tpu/status_helper.h" -#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" namespace tensorflow { diff --git a/tensorflow/stream_executor/tpu/tpu_platform.cc b/tensorflow/stream_executor/tpu/tpu_platform.cc index 5a01848e78b..41a26644483 100644 --- a/tensorflow/stream_executor/tpu/tpu_platform.cc +++ b/tensorflow/stream_executor/tpu/tpu_platform.cc @@ -147,7 +147,7 @@ void TpuPlatform::EraseEvent(stream_executor::internal::EventInterface* key) { Status TpuPlatform::TpusPerHost(int* tpus) { TF_Status* status = TF_NewStatus(); - tpu::ConfigApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status); + tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status); auto ret_status = StatusFromTF_Status(status); TF_DeleteStatus(status); return ret_status; @@ -155,7 +155,7 @@ Status TpuPlatform::TpusPerHost(int* tpus) { Status TpuPlatform::TpuMemoryLimit(int64* memory_limit) { TF_Status* status = TF_NewStatus(); - tpu::ConfigApiFn()->TpuConfigurationApi_TpuMemoryLimitFn( + tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn( reinterpret_cast(memory_limit), status); auto ret_status = StatusFromTF_Status(status); TF_DeleteStatus(status);
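After this change, all of the entry points above are reached through the single `TfTpu_OpsApiFn` struct returned by `tensorflow::tpu::OpsApiFn()`. A minimal caller sketch (illustrative only; it assumes the TPU shared library has already been loaded and `InitializeTpuStructFns` has populated the struct, and it mirrors the `TpuPlatform::TpusPerHost` hunk above):

#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/tpu/tpu_api.h"

// Queries the per-host TPU count through the consolidated ops API struct.
// Returns -1 if the underlying C API reports an error.
int TpusPerHostExample() {
  TF_Status* status = TF_NewStatus();
  int tpus = 0;
  tensorflow::tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(&tpus, status);
  const bool ok = TF_GetCode(status) == TF_OK;
  TF_DeleteStatus(status);
  return ok ? tpus : -1;
}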