[TPU 1VM] Consolidate all TPU ops related APIs into a single file
PiperOrigin-RevId: 338777305 Change-Id: I84fa82f0efeef7f5e64896018da244908ecef11a
parent df7d1daff4
commit 35d27a0335
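The change collapses the per-category function tables (TfTpu_ConfigApiFn, TfTpu_MeshStateApiFn, TfTpu_CompileApiFn, TfTpu_ExecuteApiFn, TfTpu_TpuProgramApiFn, TfTpu_UtilApiFn) and their accessors into one TfTpu_OpsApiFn table reached through tpu::OpsApiFn(). A minimal before/after sketch of a call site (the wrapper function here is hypothetical; the accessor names come from the diff below):

    #include "tensorflow/core/tpu/tpu_api.h"

    void FreeMeshState(XLA_TpuMeshState* mesh_state) {
      // Before: one accessor per API category.
      //   tensorflow::tpu::MeshStateApiFn()->TpuMeshState_FreeFn(mesh_state);
      // After: every TPU ops C API goes through the consolidated table.
      tensorflow::tpu::OpsApiFn()->TpuMeshState_FreeFn(mesh_state);
    }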
@@ -110,30 +110,15 @@ cc_library(
    ],
)

cc_library(
    name = "tpu_config_c_api",
    hdrs = ["tpu_config_c_api.h"],
    deps = [
        ":libtftpu_header",
        "//tensorflow/c:tf_status",
    ],
    alwayslink = True,
)

cc_library(
    name = "tpu_api",
    srcs = ["tpu_api.cc"],
    hdrs = ["tpu_api.h"],
    deps = [
        ":libtftpu_header",
        ":tpu_config_c_api",
        ":tpu_executor_api",
        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
    ],
)
@@ -160,20 +145,15 @@ cc_library(
        ":tpu_api",
        ":tpu_api_dlsym_set_fn",
        ":tpu_compilation_device",
        ":tpu_config_c_api",
        ":tpu_executor_init_fns",
        ":tpu_library_init_fns",
        ":tpu_node_device",
        ":tpu_ops_c_api_hdrs",
        ":tpu_system_device",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_computation_placer",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
    ],
)
@@ -292,7 +272,7 @@ cc_library(
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_executable_info_proto_cc",
        "//tensorflow/stream_executor:device_memory",
        "//tensorflow/stream_executor:stream",
@@ -351,3 +331,16 @@ cc_library(
        "//tensorflow/stream_executor/tpu:tpu_transfer_manager",
    ],
)

cc_library(
    name = "tpu_ops_c_api_hdrs",
    srcs = [],
    hdrs = ["tpu_ops_c_api.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        "//tensorflow/stream_executor/tpu:c_api_decl",
        "//tensorflow/stream_executor/tpu:proto_helper",
    ],
    alwayslink = True,
)

@@ -158,7 +158,7 @@ cc_library(
        "//tensorflow/core/protobuf/tpu:topology_proto_cc",
        "//tensorflow/core/tpu:tpu_compile_interface",
        "//tensorflow/core/tpu:tpu_defs",
        "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
        "//tensorflow/stream_executor/tpu:tpu_topology_external",
        "@com_google_absl//absl/algorithm:container",

@@ -64,9 +64,9 @@ limitations under the License.
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

@@ -56,7 +56,7 @@ cc_library(
        ":tpu_op_util",
        ":tpu_program_group_interface",
        ":tpu_util",
        ":tpu_util_c_api_hdrs",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        ":tpu_util_hdrs",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
@@ -114,7 +114,7 @@ tf_kernel_library(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:refcount",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu:tpu_config_c_api",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "//tensorflow/core/tpu:tpu_configuration",
        "//tensorflow/core/tpu:tpu_defs",
        "//tensorflow/stream_executor/lib",
@@ -123,19 +123,6 @@ tf_kernel_library(
    alwayslink = 1,
)

cc_library(
    name = "tpu_compile_c_api_hdrs",
    hdrs = ["tpu_compile_c_api.h"],
    deps = [
        ":tpu_mesh_state_c_api_hdrs",
        ":tpu_program_c_api_hdrs",
        ":tpu_util_c_api_hdrs",
        "//tensorflow/core/tpu:libtftpu_header",
        "//tensorflow/stream_executor/tpu:c_api_decl",
    ],
    alwayslink = True,
)

tf_proto_library(
    name = "tpu_executable_info_proto",
    srcs = ["tpu_executable_info.proto"],
@@ -261,24 +248,16 @@ cc_library(
    ],
)

cc_library(
    name = "tpu_mesh_state_c_api_hdrs",
    hdrs = ["tpu_mesh_state_c_api.h"],
    deps = ["//tensorflow/core/tpu:libtftpu_header"],
    alwayslink = True,
)

cc_library(
    name = "tpu_mesh_state_interface",
    srcs = [],
    hdrs = ["tpu_mesh_state_interface.h"],
    deps = [
        ":tpu_compile_c_api_hdrs",
        ":tpu_mesh_state_c_api_hdrs",
        "//tensorflow/compiler/xla/service",
        "//tensorflow/core:framework",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
    ],
)
@@ -310,14 +289,11 @@ cc_library(
    srcs = ["tpu_program_group.cc"],
    hdrs = ["tpu_program_group.h"],
    deps = [
        ":tpu_compile_c_api_hdrs",
        ":tpu_compile_op_common",
        ":tpu_compile_op_support",
        ":tpu_compile_proto_cc",
        ":tpu_executable_info_proto_cc",
        ":tpu_mesh_state_c_api_hdrs",
        ":tpu_mesh_state_interface",
        ":tpu_program_c_api_hdrs",
        ":tpu_program_group_interface",
        "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
        "//tensorflow/compiler/tf2xla:xla_compiler",
@@ -329,6 +305,7 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:proto_helper",
        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
@@ -382,11 +359,9 @@ cc_library(
        ":tpu_compilation_cache_key",
        ":tpu_compilation_metrics",  # buildcleaner: keep
        ":tpu_compilation_metrics_hdrs",
        ":tpu_compile_c_api_hdrs",
        ":tpu_compile_op_support",
        ":tpu_mesh_state_interface",
        ":tpu_op_consts",
        ":tpu_program_c_api_hdrs",
        ":tpu_program_group",
        ":tpu_util",
        ":trace_util_hdrs",
@@ -398,6 +373,7 @@ cc_library(
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
@@ -450,42 +426,18 @@ cc_library(
    ],
)

cc_library(
    name = "tpu_util_c_api_hdrs",
    hdrs = ["tpu_util_c_api.h"],
    deps = [
        ":tpu_mesh_state_c_api_hdrs",
        "//tensorflow/core/tpu:libtftpu_header",
        "//tensorflow/stream_executor/tpu:c_api_decl",
        "//tensorflow/stream_executor/tpu:proto_helper",
    ],
    alwayslink = True,
)

cc_library(
    name = "tpu_program_c_api_hdrs",
    hdrs = ["tpu_program_c_api.h"],
    deps = [
        ":tpu_util_c_api_hdrs",
        "//tensorflow/core/tpu:libtftpu_header",
        "//tensorflow/stream_executor/tpu:c_api_decl",
        "//tensorflow/stream_executor/tpu:proto_helper",
    ],
    alwayslink = True,
)

cc_library(
    name = "tpu_op_util",
    srcs = ["tpu_op_util.cc"],
    hdrs = ["tpu_op_util.h"],
    deps = [
        ":tpu_compilation_cache_key",
        ":tpu_compile_c_api_hdrs",
        ":tpu_mesh_state_interface",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "@com_google_absl//absl/strings",
    ],
)
@@ -496,7 +448,6 @@ cc_library(
    hdrs = ["tpu_util.h"],
    deps = [
        ":tpu_compilation_cache_key",
        ":tpu_util_c_api_hdrs",
        "//tensorflow/cc:ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:statusor",
@@ -504,6 +455,7 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        tf_grpc_cc_dependency(),
@@ -548,7 +500,7 @@ cc_library(
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
        "//tensorflow/core/tpu:tpu_config_c_api",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:proto_helper",
    ],
)
@@ -665,17 +617,6 @@ cc_library(
    ],
)

cc_library(
    name = "tpu_execute_c_api_hdrs",
    hdrs = ["tpu_execute_c_api.h"],
    deps = [
        ":tpu_program_c_api_hdrs",
        ":tpu_util_c_api_hdrs",
        "//tensorflow/core/tpu:libtftpu_header",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
    ],
)

cc_library(
    name = "tpu_compile_op_impl",
    srcs = ["tpu_compile_op_impl.cc"],
@@ -683,11 +624,9 @@ cc_library(
    copts = tf_copts(),
    deps = [
        ":tpu_compilation_cache_key",
        ":tpu_compile_c_api_hdrs",
        ":tpu_compile_op_common",
        ":tpu_compile_op_support",
        ":tpu_compile_proto_cc",
        ":tpu_mesh_state_c_api_hdrs",
        ":tpu_program_group",
        ":tpu_program_group_interface",
        ":tpu_util",
@@ -696,6 +635,7 @@ cc_library(
        "//tensorflow/compiler/xla:status",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_executor",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "@com_google_absl//absl/types:variant",

@@ -25,11 +25,10 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/trace_util.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

@@ -34,11 +34,11 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

@@ -434,7 +434,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(

  // Check if caller has disabled compilation. Set using
  // internal::ScopedTpuCompileDisabler.
  if (!UtilApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
  if (!OpsApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
    const std::string error_msg = strings::StrCat(
        "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
        "disabled, session_name(",

@@ -1,44 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_

#include <stddef.h>

#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"

extern "C" {

// Compiles Mlir or TF function computation by lowering into HLO IR and returns
// `count` number of TPU programs ready for execution.
// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates
// `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is
// responsible to deallocate both the `XLA_TpuProgram*[]` array and the
// `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free`
// API respectively.
TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild(
    TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state,
    XLA_TpuProgram** tpu_programs[], size_t* count, SE_Status* status);

struct TfTpu_CompileApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);
};

} // extern "C"

#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_

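The ownership contract documented in the deleted header above carries over to the consolidated table: TpuCompile_CompileAndBuild allocates the program array, and the caller frees it. A sketch of a conforming caller (error handling trimmed; `request` and `mesh_state` are assumed to be supplied by the surrounding op, as in TpuProgramGroup::CompileAndBuild later in this diff):

    size_t count = 0;
    XLA_TpuProgram** xla_tpu_programs = nullptr;
    StatusHelper status;  // from tensorflow/stream_executor/tpu/status_helper.h
    tensorflow::tpu::OpsApiFn()->TpuCompile_CompileAndBuildFn(
        request, mesh_state, &xla_tpu_programs, &count, status.c_status);
    if (status.ok()) {
      // ... take ownership of the `count` programs, then release the array ...
      tensorflow::tpu::OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
    }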
@@ -39,10 +39,10 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_op_util.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {
@@ -543,7 +543,7 @@ void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
      ctx->cancellation_manager()->get_cancellation_token();
  const bool already_cancelled =
      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
        if (UtilApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
        if (OpsApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
          return;
        }

@@ -17,9 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

@@ -23,8 +23,8 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

@@ -32,9 +32,9 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
@@ -203,12 +203,11 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  TF_Status* status = TF_NewStatus();
  auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
    TF_DeleteStatus(status);
    tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
        tpu_topology_output);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
  });

  auto* mesh_common_state = mesh_state->mesh_common_state();
  tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(
      num_hosts, num_devices_per_host,
      const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
      &tpu_topology_output_size, &tpu_topology_output, status);
@@ -247,7 +246,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  auto tpu_host_config = ctx->input(0).scalar<tstring>()();

  bool is_master_worker =
      tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
      tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
  if (!is_master_worker) {
    // Reset the mesh interface if we are not the master.
    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
@@ -283,9 +282,9 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  int32_t* device_id_output = nullptr;
  auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
    TF_DeleteStatus(status);
    tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
  });
  tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
      tpu_host_config.size(), tpu_host_config.data(),
      enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size,
      &device_id_output, status);
@@ -302,16 +301,16 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
            tpu::kCompiledProtoCacheResourceName, proto_lookup));
  } else {
    int64_t cache_size_bytes;
    tpu::ConfigApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
    tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
        &cache_size_bytes);

    char* server_address_output = nullptr;
    auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
      tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
      tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(
          server_address_output);
    });
    size_t server_address_output_size;
    tpu::ConfigApiFn()
    tpu::OpsApiFn()
        ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
            tpu_host_config.size(), tpu_host_config.data(),
            &server_address_output_size, &server_address_output, status);
@@ -346,8 +345,8 @@ void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
  auto tpu_topology = ctx->input(0).scalar<tstring>()();
  TF_Status* status = TF_NewStatus();

  tpu::ConfigApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                   tpu_topology.data(), status);
  tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                tpu_topology.data(), status);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);
@@ -362,7 +361,7 @@ void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
  TF_Status* status = TF_NewStatus();
  int32_t number_of_chips_output = 0;

  tpu::ConfigApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
  tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
      &number_of_chips_output, status);

  Tensor* ctx_output;

@@ -1,59 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_

#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

extern "C" {

typedef struct XLA_DeviceAssignment {
  const char* bytes;
  size_t size;
} XLA_DeviceAssignment;

TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
    const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments,
    size_t arguments_len, SE_DeviceMemoryBase* result,
    SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed,
    XLA_DeviceAssignment* device_assignment, SE_Stream* stream,
    SE_Status* status);

TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
    XLA_Shape* host_shape, XLA_Shape* device_shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);

TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
    uint32_t* runtime_input_ptr, size_t runtime_input_size,
    int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape,
    XLA_Shape* compile_time_shape, SE_Status* status);

struct TfTpu_ExecuteApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);
  TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);
};

} // extern "C"

#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_

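After the move, these hardware-layout helpers become fields of the same consolidated ops table. A sketch of measuring a compact shape size, mirroring the ApiConverter pattern in the last hunk of this diff (`shape` is an xla::Shape assumed to come from the caller):

    XLA_Shape c_shape;
    ApiConverter::ToC(shape, &c_shape);
    int64 size =
        tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape);
    ApiConverter::Free(&c_shape);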
@@ -1,43 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_

#include "tensorflow/core/tpu/libtftpu.h"

typedef struct XLA_TpuMeshState XLA_TpuMeshState;

extern "C" {

// Creates a new TPU mesh state object.
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();

// Deletes the given TPU `mesh_state` object. Once deleted the object is
// unusable.
TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);

// Returns a pointer to an opaque mesh data structure used internally.
TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
    XLA_TpuMeshState* mesh_state);

} // extern "C"

struct TfTpu_MeshStateApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
};

#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_

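These three entry points define the mesh-state lifecycle that TpuMeshStateInterface (next hunk) wraps; a standalone sketch of the same ownership discipline through the consolidated table:

    XLA_TpuMeshState* mesh_state =
        tensorflow::tpu::OpsApiFn()->TpuMeshState_CreateFn();
    // The opaque common state is only valid while the mesh state is alive.
    void* common_state =
        tensorflow::tpu::OpsApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state);
    (void)common_state;
    tensorflow::tpu::OpsApiFn()->TpuMeshState_FreeFn(mesh_state);  // unusable afterwards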
@@ -19,9 +19,8 @@ limitations under the License.

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {

@@ -39,19 +38,19 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {

  ~TpuMeshStateInterface() override {
    if (mesh_state_ != nullptr) {
      MeshStateApiFn()->TpuMeshState_FreeFn(mesh_state_);
      OpsApiFn()->TpuMeshState_FreeFn(mesh_state_);
    }
  }

  static TpuMeshStateInterface* Create() {
    return new TpuMeshStateInterface(MeshStateApiFn()->TpuMeshState_CreateFn());
    return new TpuMeshStateInterface(OpsApiFn()->TpuMeshState_CreateFn());
  }

  const XLA_TpuMeshState* data() const { return mesh_state_; }

  tensorflow::TpuMeshCommonState* mesh_common_state() const {
    return static_cast<tensorflow::TpuMeshCommonState*>(
        MeshStateApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_));
        OpsApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_));
  }

  // Returns whether we should include the device assignment as a static field
@@ -63,8 +62,8 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {
    // Static device assignment enables XLA to perform certain optimization when
    // all cores are used in the replicated computation.
    return metadata.num_cores_per_replica() * metadata.num_replicas() ==
           UtilApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_,
                                                         tpu_core_type);
           OpsApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_,
                                                        tpu_core_type);
  }

  string DebugString() const override { return "TpuMeshStateInterface"; }

@@ -17,7 +17,7 @@ limitations under the License.
#include <string>

#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {
@@ -77,7 +77,7 @@ std::string GuaranteedConstFingerprint(
  uint64_t fingerprint = 0;
  for (const Tensor& constant : guaranteed_constants) {
    fingerprint =
        tpu::UtilApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn(
        tpu::OpsApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn(
            fingerprint, constant.tensor_data().data(),
            constant.tensor_data().size());
  }
@@ -110,7 +110,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
    }
  }
  CompilationCacheKeyResult result =
      tpu::UtilApiFn()->TpuCompile_CreateCompilationCacheKeyFn(
      tpu::OpsApiFn()->TpuCompile_CreateCompilationCacheKeyFn(
          CompilationCacheKeyProperty{
              config_prefix.data(),
              shapes_prefix.data(),
@@ -125,7 +125,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
              mesh_state.data(),
          });
  auto buffer_cleanup = gtl::MakeCleanup([result]() {
    tpu::UtilApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
    tpu::OpsApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
  });
  TpuCompilationCacheKey key;
  key.prefix = result.key;

@@ -74,12 +74,11 @@ Status GetServerAddressAndPort(std::string* server_address, int* serving_port) {
  char* server_address_output = nullptr;
  auto cleanup = xla::MakeCleanup([&status, &server_address_output]() {
    TF_DeleteStatus(status);
    tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
        server_address_output);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(server_address_output);
  });
  size_t server_address_output_size;
  *serving_port = -1;
  tpu::ConfigApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(
  tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(
      &server_address_output_size, &server_address_output, serving_port,
      status);
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
@@ -98,7 +97,7 @@ TpuPodState::~TpuPodState() {
  VLOG(1) << "Shutting down Compilation Cache Service.";
  if (cache_service_->Shutdown(20)) {
    if (service_port_ >= 0) {
      tpu::UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
      tpu::OpsApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
    }
  } else {
    LOG(ERROR)
@@ -150,10 +149,10 @@ Status ConstructTpuPodState(

  char* host_config_output = nullptr;
  auto host_config_cleanup = xla::MakeCleanup([&host_config_output]() {
    tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
  });
  size_t host_config_output_size;
  tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
  tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
      num_devices_per_host.size(), num_devices_per_host.data(),
      server_address.size(), server_address.data(), &host_config_output_size,
      &host_config_output, status);

@@ -1,135 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_

#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

typedef struct XLA_TpuProgram XLA_TpuProgram;

// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };

struct TpuExecutableSerializedProto {
  const char* bytes;
  size_t size;
};

struct CompilerMetadataSerializedProto {
  const char* bytes;
  size_t size;
};

struct HostComputeMetadataSerializedProto {
  const char* bytes;
  size_t size;
};

extern "C" {

// Creates a new TPU program.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New();

// Destroys the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program);

// Creates an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count);

// Destroys an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]);

// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
// destroyed, it is in an unusable state.
TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
                                                   SE_Status* status);

// Gets TPU program size in bytes from the `tpu_program`.
TFTPU_CAPI_EXPORT int64_t
TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);

// Logs the summary of current memory state snapshot of the `tpu_program`.
TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary(
    const XLA_TpuProgram* tpu_program);

// Gets TPU program executable info from the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info,
    SE_Status* status);

// Gets host transfer info proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info,
    SE_Status* status);

// Gets HLO metadata proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata,
    SE_Status* status);

// Gets may modify variables boolean value.
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
    const XLA_TpuProgram* tpu_program, bool* may_modify_variables);

// Checks if TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
    const XLA_TpuProgram* tpu_program);

// Gets TPU program by sharding type. Return value is valid only when the
// `status.status()` returns `OK`.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
    XLA_TpuProgram* tpu_program, TpuProgramShardingType type);

// Gets TPU executable proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
    const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
    SE_Status* status);

// Gets compilation metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
    const XLA_TpuProgram* tpu_program,
    CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status);

// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
    TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
    SE_Status* status);

struct TfTpu_TpuProgramApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);
};

} // extern "C"

#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_

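Per the comments in the deleted header, unloading also destroys a program, so a caller must not touch it afterwards. A sketch with the same shape as TpuProgramGroup::UnloadAndDestroyPrograms in the next file (`tpu_programs` is assumed to be a caller-owned container):

    for (XLA_TpuProgram* tpu_program : tpu_programs) {
      StatusHelper status;
      tensorflow::tpu::OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program,
                                                                 status.c_status);
      if (!status.ok()) {
        LOG(ERROR) << "UnloadAndDestroy failed: " << status.status().ToString();
      }
    }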
@@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"

@@ -39,7 +38,7 @@ TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo(
  VLOG(1) << "ConstructExecutableInfo";
  TpuSerializedProto serialized_executable_info = {};
  StatusHelper status;
  TpuProgramApiFn()->TpuProgram_GetExecutableInfoFn(
  OpsApiFn()->TpuProgram_GetExecutableInfoFn(
      xla_tpu_program, &serialized_executable_info, status.c_status);
  TPUExecutableInfoProto executable_info;
  if (status.ok()) {
@@ -55,7 +54,7 @@ TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo(
  VLOG(1) << "ConstructHostTransferInfo";
  TpuSerializedProto serialized_host_transfer_info = {};
  StatusHelper status;
  TpuProgramApiFn()->TpuProgram_GetHostTransferInfoFn(
  OpsApiFn()->TpuProgram_GetHostTransferInfoFn(
      xla_tpu_program, &serialized_host_transfer_info, status.c_status);
  TPUHostTransferInfoProto host_transfer_info;
  if (status.ok()) {
@@ -71,7 +70,7 @@ xla::HloProto TpuProgramGroup::ConstructHloMetadata(
  VLOG(1) << "ConstructHloMetadata";
  TpuSerializedProto serialized_hlo_metadata = {};
  StatusHelper status;
  TpuProgramApiFn()->TpuProgram_GetHloMetadataFn(
  OpsApiFn()->TpuProgram_GetHloMetadataFn(
      xla_tpu_program, &serialized_hlo_metadata, status.c_status);
  xla::HloProto hlo_metadata;
  if (status.ok()) {
@@ -97,8 +96,8 @@ void TpuProgramGroup::Initialize(
  for (size_t i = 0; i < tpu_programs_.size(); ++i) {
    const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i];
    bool may_modify_variables;
    TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(
        xla_tpu_program, &may_modify_variables);
    OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program,
                                                   &may_modify_variables);
    may_modify_variables_array[i] = may_modify_variables;
    executable_infos[i] = ConstructExecutableInfo(xla_tpu_program);
    host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program);
@@ -114,7 +113,7 @@ void TpuProgramGroup::Initialize(

bool TpuProgramGroup::has_sharding_program() const {
  for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
    if (!TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
    if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
      return false;
    }
  }
@@ -126,7 +125,7 @@ size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
int64_t TpuProgramGroup::program_size() const {
  int64_t total_size = 0;
  for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
    total_size += TpuProgramApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
    total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
  }
  return total_size;
}
@@ -134,8 +133,7 @@ int64_t TpuProgramGroup::program_size() const {
bool TpuProgramGroup::LogProgramMemorySummary() {
  bool success = true;
  for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
    success &=
        TpuProgramApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
    success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
  }
  return success;
}
@@ -143,8 +141,7 @@ bool TpuProgramGroup::LogProgramMemorySummary() {
void TpuProgramGroup::UnloadAndDestroyPrograms() {
  for (XLA_TpuProgram* tpu_program : tpu_programs_) {
    StatusHelper status;
    TpuProgramApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program,
                                                     status.c_status);
    OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status);
    auto s = status.status();
    if (!s.ok()) {
      LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString();
@@ -208,8 +205,8 @@ bool TpuProgramGroup::may_modify_variables(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  bool may_modify_variables;
  TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
                                                        &may_modify_variables);
  OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
                                                 &may_modify_variables);
  return may_modify_variables;
}

@@ -258,9 +255,9 @@ Status TpuProgramGroup::CompileAndBuild(
  size_t count = 0;
  XLA_TpuProgram** xla_tpu_programs = nullptr;
  StatusHelper status;
  CompileApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
                                               mesh_state, &xla_tpu_programs,
                                               &count, status.c_status);
  OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
                                           mesh_state, &xla_tpu_programs,
                                           &count, status.c_status);
  if (!status.ok()) {
    VLOG(1) << "Run CompileAndBuild failed.";
    return status.status();
@@ -275,7 +272,7 @@ Status TpuProgramGroup::CompileAndBuild(
      tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
  tpu_program_group->Initialize(
      absl::MakeConstSpan(&xla_tpu_programs[0], count));
  TpuProgramApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
  OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
  return status.status();
}

@@ -284,8 +281,8 @@ std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
  std::vector<XLA_TpuProgram*> tpu_programs;
  tpu_programs.reserve(tpu_programs_.size());
  for (size_t i = 0; i < tpu_programs_.size(); ++i) {
    if (TpuProgramApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
      tpu_programs.push_back(TpuProgramApiFn()->TpuProgram_GetTpuProgramFn(
    if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
      tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn(
          tpu_programs_[i], sharding_type));
      CHECK_NE(tpu_programs[i], nullptr);
    }
@@ -300,11 +297,11 @@ Status TpuProgramGroup::DeserializeFromRpcResponseProtos(

  for (size_t i = 0; i < rpc_response_protos.size(); ++i) {
    StatusHelper status;
    auto* xla_tpu_program = TpuProgramApiFn()->TpuProgram_NewFn();
    TpuProgramApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
    auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn();
    OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
        rpc_response_protos[i], xla_tpu_program, status.c_status);
    if (!status.status().ok()) {
      TpuProgramApiFn()->TpuProgram_FreeFn(xla_tpu_program);
      OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program);
      return status.status();
    }
    tpu_programs[i] = xla_tpu_program;
@@ -319,8 +316,8 @@ Status TpuProgramGroup::SerializeExecutable(
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  StatusHelper status;
  TpuProgramApiFn()->TpuProgram_SerializeTpuExecutableFn(
      tpu_programs_[index], executable, status.c_status);
  OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index],
                                                  executable, status.c_status);
  return status.status();
}

@@ -329,7 +326,7 @@ Status TpuProgramGroup::SerializeCompilerMetadata(
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  StatusHelper status;
  TpuProgramApiFn()->TpuProgram_SerializeCompilerMetadataFn(
  OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn(
      tpu_programs_[index], compiler_metadata, status.c_status);
  return status.status();
}

@@ -27,10 +27,9 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

@@ -1,91 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_

#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

// Property for creating compilation cache key.
struct CompilationCacheKeyProperty {
  const char* config_prefix;
  const char* shapes_prefix;
  const char* function_name;
  const char* mlir_module;
  const int32_t* device_ids;
  size_t device_ids_size;
  int32_t guaranteed_constants_size;
  uint64_t function_library_fingerprint;
  int32_t num_cores_per_replica;
  int32_t num_replicas;
  const XLA_TpuMeshState* mesh_state;
};

// Compilation cache key result returning both the key and a more verbose debug
// version.
struct CompilationCacheKeyResult {
  const char* key;
  const char* debug_string;
};

extern "C" {

// Checks if whether a TPU compilation is enabled.
TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled();

// XLA compilation cannot be cancelled. To avoid hanging the TF worker will exit
// when cancellation is requested for an XLA compile op. Some tests require this
// behavior to be disabled, and we test for this condition with the following
// flag function.
TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();

// Returns the number of available TPU core count.
TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
    const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);

// Recycle unused service port.
TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port);

// Creates a unique compilation cache `key` used for `put` and `get` operations.
// Returned buffers are heap-allocated and must be owned.
TFTPU_CAPI_EXPORT CompilationCacheKeyResult
TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property);

// Destroys the CompilationCacheKeyResult returned by calling the
// `TpuCompile_CreateCompilationCacheKey` API.
TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey(
    CompilationCacheKeyResult result);

// Creates a guaranteed const fingerprint. Guarantee const is normally used in
// TPU inference to avoid re-copying unchanged variables onto the TPU device.
// It promises the value is identical for every execution in the same session
// even if the actual value changes in later executions.
TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
    uint64_t fingerprint, const char* data, size_t size);

} // extern "C"

struct TfTpu_UtilApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
  TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
  TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);
};

#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_

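The create/destroy pair for cache keys is used exactly as documented above; a sketch matching the tpu_op_util.cc hunk earlier in this diff (`property` is a CompilationCacheKeyProperty assumed to be filled in by the caller):

    CompilationCacheKeyResult result =
        tensorflow::tpu::OpsApiFn()->TpuCompile_CreateCompilationCacheKeyFn(property);
    // Returned buffers are heap-allocated, so pair every create with a destroy.
    auto buffer_cleanup = tensorflow::gtl::MakeCleanup([result]() {
      tensorflow::tpu::OpsApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
    });
    std::string key = result.key;  // copy out before the cleanup runs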
@@ -23,39 +23,9 @@ TfTpu_BaseFn* InitializeApiFn() {
  return &base_fn;
}

TfTpu_ConfigApiFn* ConfigApiFn() {
  static TfTpu_ConfigApiFn config_api_fn;
  return &config_api_fn;
}

TfTpu_MeshStateApiFn* MeshStateApiFn() {
  static TfTpu_MeshStateApiFn mesh_state_api_fn;
  return &mesh_state_api_fn;
}

TfTpu_CompileApiFn* CompileApiFn() {
  static TfTpu_CompileApiFn compile_api_fn;
  return &compile_api_fn;
}

TfTpu_ExecuteApiFn* ExecuteApiFn() {
  static TfTpu_ExecuteApiFn execute_api_fn;
  return &execute_api_fn;
}

TfTpu_TpuProgramApiFn* TpuProgramApiFn() {
  static TfTpu_TpuProgramApiFn tpu_program_api_fn;
  return &tpu_program_api_fn;
}

TfTpu_NodeContextApiFn* NodeContextApiFn() {
  static TfTpu_NodeContextApiFn node_context_api_fn;
  return &node_context_api_fn;
}

TfTpu_UtilApiFn* UtilApiFn() {
  static TfTpu_UtilApiFn util_api_fn;
  return &util_api_fn;
TfTpu_OpsApiFn* OpsApiFn() {
  static TfTpu_OpsApiFn ops_api_fn;
  return &ops_api_fn;
}

} // namespace tpu

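The single static table above is populated at library-load time through the dlsym targets touched in the BUILD changes (tpu_api_dlsym_set_fn, tpu_library_init_fns). A rough, hypothetical sketch of that population step — the macro and function below are illustrative only, not the actual initializer:

    #include <dlfcn.h>

    // Hypothetical helper: resolve one exported symbol into the ops table.
    // Expands inside PopulateOpsApi, where `library_handle` is in scope.
    #define SET_OPS_FN(table, FnName)                                        \
      (table)->FnName##Fn = reinterpret_cast<decltype((table)->FnName##Fn)>( \
          dlsym(library_handle, #FnName))

    void PopulateOpsApi(void* library_handle) {
      TfTpu_OpsApiFn* ops = tensorflow::tpu::OpsApiFn();
      SET_OPS_FN(ops, TpuMeshState_Create);
      SET_OPS_FN(ops, TpuMeshState_Free);
      // ... one entry per TFTPU_ADD_FN_IN_STRUCT member ...
    }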
@@ -16,33 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_TPU_API_H_
#define TENSORFLOW_CORE_TPU_TPU_API_H_

#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_executor_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

TfTpu_BaseFn* InitializeApiFn();

TfTpu_ConfigApiFn* ConfigApiFn();

TfTpu_MeshStateApiFn* MeshStateApiFn();

TfTpu_CompileApiFn* CompileApiFn();

TfTpu_ExecuteApiFn* ExecuteApiFn();

TfTpu_TpuProgramApiFn* TpuProgramApiFn();

TfTpu_NodeContextApiFn* NodeContextApiFn();

TfTpu_UtilApiFn* UtilApiFn();
TfTpu_OpsApiFn* OpsApiFn();

} // namespace tpu
} // namespace tensorflow

@@ -17,14 +17,9 @@ limitations under the License.
#define TENSORFLOW_CORE_TPU_TPU_API_DLSYM_INITIALIZER_H_

#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"

// LINT.IfChange
namespace tensorflow {

@@ -1,97 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
#define TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_

#include <cstddef>
#include <cstdint>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/tpu/libtftpu.h"

typedef struct TpuSerializedProto TpuSerializedProto;

namespace tensorflow {
class TpuMeshCommonState;
} // namespace tensorflow

extern "C" {

TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
    const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
    size_t server_address_size, const char* server_address,
    size_t* host_config_output_size, char** host_config_output,
    TF_Status* status);

TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
    const size_t num_hosts, const size_t num_cores_per_host,
    const int32_t** host_ordinal_to_global_core_id_map,
    tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
    size_t* tpu_topology_output_size, char** tpu_topology_output,
    TF_Status* status);

TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
    const size_t tpu_host_config_size, const char* tpu_host_config,
    const bool enable_whole_mesh_compilations, bool is_master_worker,
    size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);

TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
    const size_t tpu_topology_size, const char* tpu_topology,
    TF_Status* status);

TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
    int32_t* number_of_chips_output, TF_Status* status);

TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);

TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();

TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
                                                       TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
                                                          TF_Status* status);

TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
    int64_t* cache_size_in_bytes);
TFTPU_CAPI_EXPORT
void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
    size_t tpu_host_config_size, const char* tpu_host_config,
    size_t* server_address_output_size, char** server_address_output,
    TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
    size_t* server_address_output_size, char** server_address_output,
    int* port_output, TF_Status* status);
}

struct TfTpu_ConfigApiFn {
  TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
  TFTPU_ADD_FN_IN_STRUCT(
      TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);
};

#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_

@ -107,7 +107,7 @@ xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
  tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
      &c_host_shape, &c_device_shape);
  xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Free(&c_host_shape);
@ -119,8 +119,7 @@ int64 ShapeSizeCompact(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64 size =
      tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactFn(
          &c_shape);
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape);
  ApiConverter::Free(&c_shape);
  return size;
}
@ -129,7 +128,7 @@ int64 ShapeSizeCompactRaw(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64 size =
      tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
          &c_shape);
  ApiConverter::Free(&c_shape);
  return size;
@ -241,11 +240,10 @@ xla::Status UpdateDynamicInputs(
    ApiConverter::ToC(runtime_shape, &c_runtime_shape);
    ApiConverter::ToC(compile_time_shape, &c_compile_time_shape);
    StatusHelper status;
    tensorflow::tpu::ExecuteApiFn()
        ->TpuExecute_RuntimeInputToPaddedDataFn(
            raw_input_runtime->data(), raw_input_runtime->size(),
            padded_data->data(), padded_data->size(), &c_runtime_shape,
            &c_compile_time_shape, status.c_status);
    tensorflow::tpu::OpsApiFn()->TpuExecute_RuntimeInputToPaddedDataFn(
        raw_input_runtime->data(), raw_input_runtime->size(),
        padded_data->data(), padded_data->size(), &c_runtime_shape,
        &c_compile_time_shape, status.c_status);
    ApiConverter::Free(&c_runtime_shape);
    ApiConverter::Free(&c_compile_time_shape);
    return status.status();
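Every hunk in this file follows the same C-boundary pattern: convert the XLA shape to its C form, call through the consolidated ops struct, then free the C copy. A hedged sketch of that pattern as a standalone helper (illustrative only; the helper name is not from this commit):

#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"

int64_t CompactShapeSizeBytes(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);  // may heap-allocate C-side buffers
  int64_t size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape);
  ApiConverter::Free(&c_shape);  // always release what ToC allocated
  return size;
}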
@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
@ -6,118 +6,76 @@

namespace {

tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
  auto* config_fn = tensorflow::tpu::ConfigApiFn();
tensorflow::Status SetTpuOpsStructFns(void* library_handle) {
  auto* ops_api_fn = tensorflow::tpu::OpsApiFn();

  TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
  TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
  TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
  TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
  TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
  TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
  TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
  TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
  TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpusPerHost);
  TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpuMemoryLimit);
  TFTPU_SET_FN(config_fn,
  TFTPU_SET_FN(ops_api_fn, ConfigureDistributedTpuOp_DoWork);
  TFTPU_SET_FN(ops_api_fn, WaitForDistributedTpuOp_DoWork);
  TFTPU_SET_FN(ops_api_fn, InitializeHostForDistributedTpuOp_DoWork);
  TFTPU_SET_FN(ops_api_fn, SetGlobalTPUArrayOp_DoWork);
  TFTPU_SET_FN(ops_api_fn, DisconnectDistributedTpuChipsOp_DoWork);
  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeCharArray);
  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_FreeInt32Array);
  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_HasTPUPodState);
  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpusPerHost);
  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_TpuMemoryLimit);
  TFTPU_SET_FN(ops_api_fn,
               TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
  TFTPU_SET_FN(config_fn,
  TFTPU_SET_FN(ops_api_fn,
               TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
  TFTPU_SET_FN(config_fn, TpuConfigurationApi_GetServerAddressAndPort);
  TFTPU_SET_FN(ops_api_fn, TpuConfigurationApi_GetServerAddressAndPort);

  return tensorflow::Status::OK();
}
  TFTPU_SET_FN(ops_api_fn, TpuMeshState_Create);
  TFTPU_SET_FN(ops_api_fn, TpuMeshState_Free);
  TFTPU_SET_FN(ops_api_fn, TpuMeshState_MeshCommonState);

tensorflow::Status SetTpuMeshStateStructFns(void* library_handle) {
  auto* mesh_state_fn = tensorflow::tpu::MeshStateApiFn();
  TFTPU_SET_FN(ops_api_fn, TpuCompile_CompileAndBuild);

  TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Create);
  TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Free);
  TFTPU_SET_FN(mesh_state_fn, TpuMeshState_MeshCommonState);
  TFTPU_SET_FN(ops_api_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
  TFTPU_SET_FN(ops_api_fn, HardwareLayout_HostShapeToDeviceShape);
  TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSize);
  TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompact);
  TFTPU_SET_FN(ops_api_fn, HardwareLayout_ShapeSizeCompactRaw);
  TFTPU_SET_FN(ops_api_fn, TpuExecute_RuntimeInputToPaddedData);

  return tensorflow::Status::OK();
}

tensorflow::Status SetCompileStructFn(void* library_handle) {
  auto* compile_fn = tensorflow::tpu::CompileApiFn();

  TFTPU_SET_FN(compile_fn, TpuCompile_CompileAndBuild);

  return tensorflow::Status::OK();
}

tensorflow::Status SetExecuteStructFn(void* library_handle) {
  auto* execute_fn = tensorflow::tpu::ExecuteApiFn();

  TFTPU_SET_FN(execute_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
  TFTPU_SET_FN(execute_fn, HardwareLayout_HostShapeToDeviceShape);
  TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSize);
  TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompact);
  TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSizeCompactRaw);
  TFTPU_SET_FN(execute_fn, TpuExecute_RuntimeInputToPaddedData);

  return tensorflow::Status::OK();
}

tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
  auto* tpu_program_fn = tensorflow::tpu::TpuProgramApiFn();

  TFTPU_SET_FN(tpu_program_fn, TpuProgram_New);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_Free);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_NewArray);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_FreeArray);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_UnloadAndDestroy);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetProgramSize);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_LogProgramMemorySummary);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetExecutableInfo);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHostTransferInfo);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetHloMetadata);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeTpuExecutable);
  TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeCompilerMetadata);
  TFTPU_SET_FN(tpu_program_fn,
  TFTPU_SET_FN(ops_api_fn, TpuProgram_New);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_Free);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_NewArray);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_FreeArray);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_UnloadAndDestroy);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetProgramSize);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_LogProgramMemorySummary);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetExecutableInfo);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHostTransferInfo);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetHloMetadata);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetMayModifyVariables);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_HasSharding);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_GetTpuProgram);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeTpuExecutable);
  TFTPU_SET_FN(ops_api_fn, TpuProgram_SerializeCompilerMetadata);
  TFTPU_SET_FN(ops_api_fn,
               TpuProgram_DeserializeFromGetTpuProgramResponseProto);

  return tensorflow::Status::OK();
}
  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Create);
  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Free);
  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_Initialize);
  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_StopChipHeartbeats);
  TFTPU_SET_FN(ops_api_fn, TpuNodeContext_CloseTpuHost);

tensorflow::Status SetTpuNodeContextStructFns(void* library_handle) {
  auto* node_context_fn = tensorflow::tpu::NodeContextApiFn();

  TFTPU_SET_FN(node_context_fn, TpuNodeContext_Create);
  TFTPU_SET_FN(node_context_fn, TpuNodeContext_Free);
  TFTPU_SET_FN(node_context_fn, TpuNodeContext_Initialize);
  TFTPU_SET_FN(node_context_fn, TpuNodeContext_StopChipHeartbeats);
  TFTPU_SET_FN(node_context_fn, TpuNodeContext_CloseTpuHost);

  return tensorflow::Status::OK();
}

tensorflow::Status SetTpuUtilStructFns(void* library_handle) {
  auto* util_fn = tensorflow::tpu::UtilApiFn();

  TFTPU_SET_FN(util_fn, TpuTopology_AvailableCoreCount);
  TFTPU_SET_FN(util_fn, TpuNetUtil_RecycleUnusedPort);
  TFTPU_SET_FN(util_fn, TpuCompile_IsTpuCompilationEnabled);
  TFTPU_SET_FN(util_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
  TFTPU_SET_FN(util_fn, TpuCompile_CreateCompilationCacheKey);
  TFTPU_SET_FN(util_fn, TpuCompile_DestroyCompilationCacheKey);
  TFTPU_SET_FN(util_fn, TpuCompile_CreateGuaranteedConstFingerprint);
  TFTPU_SET_FN(ops_api_fn, TpuTopology_AvailableCoreCount);
  TFTPU_SET_FN(ops_api_fn, TpuNetUtil_RecycleUnusedPort);
  TFTPU_SET_FN(ops_api_fn, TpuCompile_IsTpuCompilationEnabled);
  TFTPU_SET_FN(ops_api_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
  TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey);
  TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey);
  TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint);

  return tensorflow::Status::OK();
}

tensorflow::Status InitializeTpuStructFns(void* library_handle) {
  TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
  TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle));
  TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle));
  TF_RETURN_IF_ERROR(SetExecuteStructFn(library_handle));
  TF_RETURN_IF_ERROR(SetTpuProgramStructFn(library_handle));
  TF_RETURN_IF_ERROR(SetTpuOpsStructFns(library_handle));
  TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
  TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle));
  TF_RETURN_IF_ERROR(SetTpuUtilStructFns(library_handle));

  return tensorflow::Status::OK();
}
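For readers new to the dlsym initializer: each TFTPU_SET_FN(ops_api_fn, Name) resolves the exported C symbol Name from the dynamically loaded TPU library and stores it in the struct's NameFn field. Conceptually it expands to something like the sketch below (an assumption about the macro's shape, not its literal definition):

#include <dlfcn.h>

// Hypothetical expansion of TFTPU_SET_FN(ops_api_fn, TpuMeshState_Create):
// look the symbol up by name and fail fast if the library does not export it.
ops_api_fn->TpuMeshState_CreateFn =
    reinterpret_cast<decltype(&TpuMeshState_Create)>(
        dlsym(library_handle, "TpuMeshState_Create"));
if (ops_api_fn->TpuMeshState_CreateFn == nullptr) {
  return tensorflow::errors::NotFound("Symbol TpuMeshState_Create not found");
}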
tensorflow/core/tpu/tpu_ops_c_api.h (new file, 342 lines)
@ -0,0 +1,342 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_

#include <stddef.h>

#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

typedef struct TpuSerializedProto TpuSerializedProto;

namespace tensorflow {
class TpuMeshCommonState;
}  // namespace tensorflow

extern "C" {

typedef struct XLA_TpuProgram XLA_TpuProgram;

// Enum for choosing the sharding/unsharding program from an `XLA_TpuProgram`
// object.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };

struct TpuExecutableSerializedProto {
  const char* bytes;
  size_t size;
};

struct CompilerMetadataSerializedProto {
  const char* bytes;
  size_t size;
};

struct HostComputeMetadataSerializedProto {
  const char* bytes;
  size_t size;
};

typedef struct XLA_TpuMeshState XLA_TpuMeshState;

typedef struct XLA_DeviceAssignment {
  const char* bytes;
  size_t size;
} XLA_DeviceAssignment;

// Property for creating compilation cache key.
struct CompilationCacheKeyProperty {
  const char* config_prefix;
  const char* shapes_prefix;
  const char* function_name;
  const char* mlir_module;
  const int32_t* device_ids;
  size_t device_ids_size;
  int32_t guaranteed_constants_size;
  uint64_t function_library_fingerprint;
  int32_t num_cores_per_replica;
  int32_t num_replicas;
  const XLA_TpuMeshState* mesh_state;
};

// Compilation cache key result returning both the key and a more verbose debug
// version.
struct CompilationCacheKeyResult {
  const char* key;
  const char* debug_string;
};

typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;

// Compiles an MLIR or TF function computation by lowering it into HLO IR, and
// returns `count` TPU programs ready for execution.
// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates
// the `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is
// responsible for deallocating both the `XLA_TpuProgram*[]` array and the
// `XLA_TpuProgram` object(s), using the `TpuProgram_FreeArray` and
// `TpuProgram_Free` APIs respectively.
TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild(
    TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state,
    XLA_TpuProgram** tpu_programs[], size_t* count, SE_Status* status);

// Creates a new TPU mesh state object.
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();

// Deletes the given TPU `mesh_state` object. Once deleted the object is
// unusable.
TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);

// Returns a pointer to an opaque mesh data structure used internally.
TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
    XLA_TpuMeshState* mesh_state);

TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
    const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments,
    size_t arguments_len, SE_DeviceMemoryBase* result,
    SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed,
    XLA_DeviceAssignment* device_assignment, SE_Stream* stream,
    SE_Status* status);

TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
    XLA_Shape* host_shape, XLA_Shape* device_shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);

TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
    uint32_t* runtime_input_ptr, size_t runtime_input_size,
    int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape,
    XLA_Shape* compile_time_shape, SE_Status* status);

TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
    const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
    size_t server_address_size, const char* server_address,
    size_t* host_config_output_size, char** host_config_output,
    TF_Status* status);

TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
    const size_t num_hosts, const size_t num_cores_per_host,
    const int32_t** host_ordinal_to_global_core_id_map,
    tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
    size_t* tpu_topology_output_size, char** tpu_topology_output,
    TF_Status* status);

TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
    const size_t tpu_host_config_size, const char* tpu_host_config,
    const bool enable_whole_mesh_compilations, bool is_master_worker,
    size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);

TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
    const size_t tpu_topology_size, const char* tpu_topology,
    TF_Status* status);

TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
    int32_t* number_of_chips_output, TF_Status* status);

TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);

TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();

TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
                                                       TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
                                                          TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
    int64_t* cache_size_in_bytes);
TFTPU_CAPI_EXPORT
void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
    size_t tpu_host_config_size, const char* tpu_host_config,
    size_t* server_address_output_size, char** server_address_output,
    TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
    size_t* server_address_output_size, char** server_address_output,
    int* port_output, TF_Status* status);

// Creates a new TPU program.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New();

// Destroys the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program);

// Creates an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count);

// Destroys an array of `XLA_TpuProgram*`.
TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]);

// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
// destroyed, it is in an unusable state.
TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
                                                   SE_Status* status);

// Gets the TPU program size in bytes from the `tpu_program`.
TFTPU_CAPI_EXPORT int64_t
TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);

// Logs a summary of the current memory state snapshot of the `tpu_program`.
TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary(
    const XLA_TpuProgram* tpu_program);

// Gets the TPU program executable info from the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info,
    SE_Status* status);

// Gets the host transfer info proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info,
    SE_Status* status);

// Gets the HLO metadata proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata,
    SE_Status* status);

// Gets the boolean value indicating whether the program may modify variables.
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
    const XLA_TpuProgram* tpu_program, bool* may_modify_variables);

// Checks whether the TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
    const XLA_TpuProgram* tpu_program);

// Gets the TPU program for the given sharding type. The return value is valid
// only when `status.status()` returns `OK`.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
    XLA_TpuProgram* tpu_program, TpuProgramShardingType type);

// Gets the TPU executable proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
    const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
    SE_Status* status);

// Gets the compilation metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
    const XLA_TpuProgram* tpu_program,
    CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status);

// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
    TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
    SE_Status* status);

// Checks whether TPU compilation is enabled.
TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled();

// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will
// exit when cancellation is requested for an XLA compile op. Some tests
// require this behavior to be disabled, and we test for this condition with
// the following flag function.
TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();

// Returns the number of available TPU cores.
TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
    const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);

// Recycles an unused service port.
TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port);

// Creates a unique compilation cache `key` used for `put` and `get` operations.
// Returned buffers are heap-allocated and must be owned.
TFTPU_CAPI_EXPORT CompilationCacheKeyResult
TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property);

// Destroys the CompilationCacheKeyResult returned by calling the
// `TpuCompile_CreateCompilationCacheKey` API.
TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey(
    CompilationCacheKeyResult result);

// Creates a guaranteed-const fingerprint. Guaranteed const is normally used in
// TPU inference to avoid re-copying unchanged variables onto the TPU device.
// It promises the value is identical for every execution in the same session
// even if the actual value changes in later executions.
TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
    uint64_t fingerprint, const char* data, size_t size);

XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
                                          SE_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);

void TpuNodeContext_StopChipHeartbeats(SE_Status* status);

void TpuNodeContext_CloseTpuHost(SE_Status* status);

void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);

struct TfTpu_OpsApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);

  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);

  TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);
  TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);

  TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
  TFTPU_ADD_FN_IN_STRUCT(
      TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);

  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);

  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
  TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
  TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);

  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
};

}  // extern "C"

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_
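The subtlest contract in the new header is the ownership rule on TpuCompile_CompileAndBuild: the callee allocates both the program array and its elements, and the caller must free both. A hedged usage sketch (not code from this commit; error handling trimmed, and `request` and `mesh_state` are assumed to be already populated):

#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"

void CompileAndRelease(TpuSerializedProto request,
                       const XLA_TpuMeshState* mesh_state) {
  StatusHelper status;
  XLA_TpuProgram** tpu_programs = nullptr;
  size_t count = 0;
  tensorflow::tpu::OpsApiFn()->TpuCompile_CompileAndBuildFn(
      request, mesh_state, &tpu_programs, &count, status.c_status);
  if (!status.status().ok()) return;
  // ... use tpu_programs[0 .. count) ...
  for (size_t i = 0; i < count; ++i) {
    tensorflow::tpu::OpsApiFn()->TpuProgram_FreeFn(tpu_programs[i]);  // each program
  }
  tensorflow::tpu::OpsApiFn()->TpuProgram_FreeArrayFn(tpu_programs);  // the array itself
}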
@ -195,7 +195,6 @@ cc_library(
    deps = [
        ":status_helper",
        ":tpu_executor_c_api_hdrs",
        ":tpu_node_context_c_api_hdrs",
        ":tpu_platform_interface",
        "//tensorflow/compiler/xla/service",
        "//tensorflow/compiler/xla/service:backend",
@ -204,6 +203,7 @@ cc_library(
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/memory",
@ -293,9 +293,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
        "//tensorflow/core/tpu/kernels:tpu_program_c_api_hdrs",
        "//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
@ -17,8 +17,8 @@ limitations under the License.

#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
@ -79,11 +79,10 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
      run_options.run_options().stream()->implementation());
  StatusHelper status;

  tensorflow::tpu::ExecuteApiFn()
      ->TpuExecutable_LoadProgramAndEnqueueToStreamFn(
          core_program_, arguments_bases, arguments.size(), &result_base,
          (cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr),
          rng_seed, &c_dev_assign, stream, status.c_status);
  tensorflow::tpu::OpsApiFn()->TpuExecutable_LoadProgramAndEnqueueToStreamFn(
      core_program_, arguments_bases, arguments.size(), &result_base,
      (cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr),
      rng_seed, &c_dev_assign, stream, status.c_status);

  if (dev_assign != nullptr) {
    stream_executor::tpu::SerializedProto_Free(dev_assign_serialized);
@ -96,7 +95,7 @@ Shape TpuExecutable::HostShapeToDeviceShape(const Shape& host_shape) {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
  tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
      &c_host_shape, &c_device_shape);
  Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Free(&c_host_shape);
@ -108,7 +107,7 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64 size =
      tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeFn(&c_shape);
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeFn(&c_shape);
  ApiConverter::Free(&c_shape);
  return size;
}
@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"

namespace tensorflow {
namespace tpu {
@ -30,40 +30,38 @@ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
    int device_ordinal) {
  StatusHelper status;
  XLA_TpuNodeContext* node_context =
      tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(device_ordinal,
                                                       status.c_status);
      tpu::OpsApiFn()->TpuNodeContext_CreateFn(device_ordinal, status.c_status);
  if (!status.status().ok()) {
    // TpuNodeContext_CreateFn allocates a new XLA_TpuNodeContext regardless of
    // status. It needs to be freed if it's not given to a TpuNodeContext below.
    tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context);
    tpu::OpsApiFn()->TpuNodeContext_FreeFn(node_context);
    return status.status();
  }
  return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
}

TpuNodeContext::~TpuNodeContext() {
  tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context_);
  tpu::OpsApiFn()->TpuNodeContext_FreeFn(node_context_);
}

/* static */
Status TpuNodeContext::StopChipHeartbeats() {
  StatusHelper status;
  tpu::NodeContextApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status);
  tpu::OpsApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status);
  return status.status();
}

/* static */
Status TpuNodeContext::CloseTpuHost() {
  StatusHelper status;
  tpu::NodeContextApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status);
  tpu::OpsApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status);
  return status.status();
}

/* static */
Status TpuNodeContext::Initialize(int device_ordinal) {
  StatusHelper status;
  tpu::NodeContextApiFn()->TpuNodeContext_InitializeFn(device_ordinal,
                                                       status.c_status);
  tpu::OpsApiFn()->TpuNodeContext_InitializeFn(device_ordinal, status.c_status);
  return status.status();
}
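One sharp edge in the Create path above is worth calling out: the C API allocates an XLA_TpuNodeContext even when it reports failure, which is why the error branch frees it before returning. Caller-side, the wrapper hides all of that; a hedged sketch, assuming the usual StatusOr surface of this era (device ordinal 0 is illustrative):

auto node_context_or = tensorflow::tpu::TpuNodeContext::Create(/*device_ordinal=*/0);
if (!node_context_or.ok()) {
  LOG(ERROR) << node_context_or.status();  // the C-side object was already freed
  return;
}
std::unique_ptr<tensorflow::tpu::TpuNodeContext> node_context =
    std::move(node_context_or.ValueOrDie());
// ~TpuNodeContext() calls TpuNodeContext_FreeFn; no manual cleanup is needed.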
@ -24,11 +24,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {
@ -147,7 +147,7 @@ void TpuPlatform::EraseEvent(stream_executor::internal::EventInterface* key) {

Status TpuPlatform::TpusPerHost(int* tpus) {
  TF_Status* status = TF_NewStatus();
  tpu::ConfigApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status);
  tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status);
  auto ret_status = StatusFromTF_Status(status);
  TF_DeleteStatus(status);
  return ret_status;
@ -155,7 +155,7 @@ Status TpuPlatform::TpusPerHost(int* tpus) {

Status TpuPlatform::TpuMemoryLimit(int64* memory_limit) {
  TF_Status* status = TF_NewStatus();
  tpu::ConfigApiFn()->TpuConfigurationApi_TpuMemoryLimitFn(
  tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn(
      reinterpret_cast<int64_t*>(memory_limit), status);
  auto ret_status = StatusFromTF_Status(status);
  TF_DeleteStatus(status);
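Unlike the StatusHelper-based call sites earlier in this commit, these platform queries manage the raw TF_Status object directly, so the create/convert/delete triad is spelled out at each call site. A hedged sketch of the same pattern in isolation:

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/tpu/tpu_api.h"

int32_t tpus = 0;
TF_Status* status = TF_NewStatus();
tensorflow::tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(&tpus, status);
tensorflow::Status result = tensorflow::StatusFromTF_Status(status);
TF_DeleteStatus(status);  // delete unconditionally to avoid leaking the status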