Add TPU configuration ops to default TensorFlow build
PiperOrigin-RevId: 317133514 Change-Id: I33bc6d7fdbba5915bd0d1291d4e086139c07eb14
This commit is contained in:
parent
d1157c976b
commit
cf00e559d7
4
.bazelrc
4
.bazelrc
|
@ -39,6 +39,7 @@
|
||||||
#
|
#
|
||||||
# Feature and Third party library support options:
|
# Feature and Third party library support options:
|
||||||
# xla: Build TF with XLA
|
# xla: Build TF with XLA
|
||||||
|
# tpu: Build TF with TPU support
|
||||||
# using_cuda: CUDA is available to build system.
|
# using_cuda: CUDA is available to build system.
|
||||||
# cuda: Build with full cuda support.
|
# cuda: Build with full cuda support.
|
||||||
# rocm: Build with AMD GPU support (rocm).
|
# rocm: Build with AMD GPU support (rocm).
|
||||||
|
@ -180,6 +181,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
|
||||||
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
||||||
build:dbg --copt -DDEBUG_BUILD
|
build:dbg --copt -DDEBUG_BUILD
|
||||||
|
|
||||||
|
# Config to build TPU backend
|
||||||
|
build:tpu --define=with_tpu_support=true
|
||||||
|
|
||||||
build:tensorrt --action_env TF_NEED_TENSORRT=1
|
build:tensorrt --action_env TF_NEED_TENSORRT=1
|
||||||
|
|
||||||
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
||||||
|
|
|
@ -467,6 +467,13 @@ config_setting(
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This flag enables experimental TPU support
|
||||||
|
config_setting(
|
||||||
|
name = "with_tpu_support",
|
||||||
|
values = {"define": "with_tpu_support=true"},
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
# Specifies via a config setting if this is a mobile build or not, makes
|
# Specifies via a config setting if this is a mobile build or not, makes
|
||||||
# it easier to combine settings later.
|
# it easier to combine settings later.
|
||||||
selects.config_setting_group(
|
selects.config_setting_group(
|
||||||
|
|
|
@ -72,6 +72,7 @@ load(
|
||||||
"if_ios",
|
"if_ios",
|
||||||
"if_mobile",
|
"if_mobile",
|
||||||
"if_not_windows",
|
"if_not_windows",
|
||||||
|
"if_tpu",
|
||||||
"tf_android_core_proto_headers",
|
"tf_android_core_proto_headers",
|
||||||
"tf_cc_test",
|
"tf_cc_test",
|
||||||
"tf_cc_test_mkl",
|
"tf_cc_test_mkl",
|
||||||
|
@ -1093,6 +1094,8 @@ cc_library(
|
||||||
]) + if_tensorrt([
|
]) + if_tensorrt([
|
||||||
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
|
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
|
||||||
"//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
|
"//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
|
||||||
|
]) + if_tpu([
|
||||||
|
"//tensorflow/core/tpu/kernels",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -103,6 +103,7 @@ cc_library(
|
||||||
":libtftpu_header",
|
":libtftpu_header",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
],
|
],
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -116,7 +117,20 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":libtftpu_header",
|
":libtftpu_header",
|
||||||
":tpu_config_c_api",
|
":tpu_config_c_api",
|
||||||
|
":tpu_library_init_fns",
|
||||||
"//tensorflow/core/platform:errors",
|
"//tensorflow/core/platform:errors",
|
||||||
"//tensorflow/core/platform:status",
|
"//tensorflow/core/platform:status",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
|
||||||
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
|
"//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
|
||||||
|
"//tensorflow/stream_executor/tpu:tpu_platform_hdrs",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tpu_library_init_fns",
|
||||||
|
hdrs = ["tpu_library_init_fns.inc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
|
@ -3,6 +3,10 @@ load(
|
||||||
"//tensorflow/core/platform:build_config.bzl",
|
"//tensorflow/core/platform:build_config.bzl",
|
||||||
"tf_proto_library_cc",
|
"tf_proto_library_cc",
|
||||||
)
|
)
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_kernel_library",
|
||||||
|
)
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
|
@ -12,6 +16,12 @@ package(
|
||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "kernels",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [":tpu_configuration_ops"],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_compile_op_common",
|
name = "tpu_compile_op_common",
|
||||||
srcs = ["tpu_compile_op_common.cc"],
|
srcs = ["tpu_compile_op_common.cc"],
|
||||||
|
@ -50,7 +60,7 @@ cc_library(
|
||||||
hdrs = ["tpu_compile_op_options.h"],
|
hdrs = ["tpu_compile_op_options.h"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
tf_kernel_library(
|
||||||
name = "tpu_configuration_ops",
|
name = "tpu_configuration_ops",
|
||||||
srcs = ["tpu_configuration_ops.cc"],
|
srcs = ["tpu_configuration_ops.cc"],
|
||||||
hdrs = ["tpu_configuration_ops.h"],
|
hdrs = ["tpu_configuration_ops.h"],
|
||||||
|
@ -75,12 +85,13 @@ cc_library(
|
||||||
name = "tpu_compile_c_api_hdrs",
|
name = "tpu_compile_c_api_hdrs",
|
||||||
hdrs = ["tpu_compile_c_api.h"],
|
hdrs = ["tpu_compile_c_api.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":tpu_mesh_state_c_api",
|
":tpu_mesh_state_c_api_hdrs",
|
||||||
":tpu_ops_common_c_api_hdrs",
|
":tpu_ops_common_c_api_hdrs",
|
||||||
":tpu_program_c_api_hdrs",
|
":tpu_program_c_api_hdrs",
|
||||||
"//tensorflow/c:tf_datatype",
|
"//tensorflow/core/tpu:libtftpu_header",
|
||||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||||
],
|
],
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_proto_library_cc(
|
tf_proto_library_cc(
|
||||||
|
@ -197,8 +208,10 @@ cc_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_mesh_state_c_api",
|
name = "tpu_mesh_state_c_api_hdrs",
|
||||||
hdrs = ["tpu_mesh_state_c_api.h"],
|
hdrs = ["tpu_mesh_state_c_api.h"],
|
||||||
|
deps = ["//tensorflow/core/tpu:libtftpu_header"],
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -207,12 +220,11 @@ cc_library(
|
||||||
hdrs = ["tpu_mesh_state_interface.h"],
|
hdrs = ["tpu_mesh_state_interface.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":tpu_compile_c_api_hdrs",
|
":tpu_compile_c_api_hdrs",
|
||||||
":tpu_mesh_state_c_api",
|
":tpu_mesh_state_c_api_hdrs",
|
||||||
"//tensorflow/compiler/xla/service",
|
"//tensorflow/compiler/xla/service",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/platform:errors",
|
|
||||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||||
"//tensorflow/core/tpu:tpu_config_c_api",
|
"//tensorflow/core/tpu:tpu_library_loader",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -371,13 +383,16 @@ cc_library(
|
||||||
name = "tpu_util_c_api_hdrs",
|
name = "tpu_util_c_api_hdrs",
|
||||||
hdrs = ["tpu_util_c_api.h"],
|
hdrs = ["tpu_util_c_api.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/core/tpu:libtftpu_header",
|
||||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||||
],
|
],
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_ops_common_c_api_hdrs",
|
name = "tpu_ops_common_c_api_hdrs",
|
||||||
hdrs = ["tpu_ops_common_c_api.h"],
|
hdrs = ["tpu_ops_common_c_api.h"],
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -387,6 +402,7 @@ cc_library(
|
||||||
":tpu_ops_common_c_api_hdrs",
|
":tpu_ops_common_c_api_hdrs",
|
||||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||||
],
|
],
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
|
|
||||||
enum TpuCoreTypeEnum {
|
enum TpuCoreTypeEnum {
|
||||||
|
@ -44,35 +45,41 @@ struct CompilationCacheKeyProperty {
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
// Returns the number of available TPU core count.
|
// Returns the number of available TPU core count.
|
||||||
int TpuTopology_AvailableCoreCount(const XLA_TpuMeshState* mesh_state,
|
TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
|
||||||
TpuCoreTypeEnum tpu_core_type);
|
const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);
|
||||||
|
|
||||||
// Creates a unique compilation cache `key` used for `put` and `get` operations.
|
// Creates a unique compilation cache `key` used for `put` and `get` operations.
|
||||||
// Returned buffer is heap-allocated and must be owned.
|
// Returned buffer is heap-allocated and must be owned.
|
||||||
const char* TpuCompile_CreateCompilationCacheKey(
|
TFTPU_CAPI_EXPORT const char* TpuCompile_CreateCompilationCacheKey(
|
||||||
CompilationCacheKeyProperty property);
|
CompilationCacheKeyProperty property);
|
||||||
|
|
||||||
// Creates a guaranteed const fingerprint. Guarantee const is normally used in
|
// Creates a guaranteed const fingerprint. Guarantee const is normally used in
|
||||||
// TPU inference to avoid re-copying unchanged variables onto the TPU device.
|
// TPU inference to avoid re-copying unchanged variables onto the TPU device.
|
||||||
// It promises the value is identical for every execution in the same session
|
// It promises the value is identical for every execution in the same session
|
||||||
// even if the actual value changes in later executions.
|
// even if the actual value changes in later executions.
|
||||||
uint64_t TpuCompile_CreateGuaranteedConstFingerprint(uint64_t fingerprint,
|
TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
|
||||||
const char* data,
|
uint64_t fingerprint, const char* data, size_t size);
|
||||||
size_t size);
|
|
||||||
|
|
||||||
// Executes the computations using XLA TPU compiler and returns TPU programs
|
// Executes the computations using XLA TPU compiler and returns TPU programs
|
||||||
// ready for execution.
|
// ready for execution.
|
||||||
void TpuCompile_CompileAheadOfTime(
|
TFTPU_CAPI_EXPORT void TpuCompile_CompileAheadOfTime(
|
||||||
TpuSerializedProto aot_compilation_request,
|
TpuSerializedProto aot_compilation_request, XLA_TpuProgram** tpu_programs[],
|
||||||
XLA_TpuProgram** tpu_programs[],
|
|
||||||
size_t* count, SE_Status* status);
|
size_t* count, SE_Status* status);
|
||||||
|
|
||||||
// Builds `DeviceAssignment` from `TpuCompileMetadata` serialized proto.
|
// Builds `DeviceAssignment` from `TpuCompileMetadata` serialized proto.
|
||||||
void TpuCompile_BuildXLADeviceAssignment(
|
TFTPU_CAPI_EXPORT void TpuCompile_BuildXLADeviceAssignment(
|
||||||
TpuSerializedProto serialized_tpu_compile_metadata,
|
TpuSerializedProto serialized_tpu_compile_metadata,
|
||||||
const XLA_TpuMeshState* mesh_state,
|
const XLA_TpuMeshState* mesh_state,
|
||||||
TpuSerializedProto* serialized_device_assignment, SE_Status* status);
|
TpuSerializedProto* serialized_device_assignment, SE_Status* status);
|
||||||
|
|
||||||
|
struct TfTpu_CompileApiFn {
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAheadOfTime);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_BuildXLADeviceAssignment);
|
||||||
|
};
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
|
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
|
||||||
|
|
|
@ -353,7 +353,7 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogAndExit(42);
|
std::quick_exit(42);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(
|
/* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(
|
||||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/tpu/tpu_config_c_api.h"
|
#include "tensorflow/core/tpu/tpu_config_c_api.h"
|
||||||
#include "tensorflow/core/tpu/tpu_configuration.h"
|
#include "tensorflow/core/tpu/tpu_configuration.h"
|
||||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -97,13 +98,14 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||||
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
|
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
|
||||||
rmgr, tpu::kTpuMeshCommonStateResourceName));
|
rmgr, tpu::kTpuMeshCommonStateResourceName));
|
||||||
|
|
||||||
ConfigureDistributedTpuOp_DoWork(
|
tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
|
||||||
num_devices_per_host.size(), num_devices_per_host.data(),
|
num_devices_per_host.size(), num_devices_per_host.data(),
|
||||||
&host_config_output_size, &host_config_output, status);
|
&host_config_output_size, &host_config_output, status);
|
||||||
|
|
||||||
OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
|
auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
|
||||||
tpu::kTpuMeshCommonStateResourceName,
|
OP_REQUIRES_OK(ctx,
|
||||||
tpu::TpuMeshStateInterface::Create()));
|
rmgr->Create(rmgr->default_container(),
|
||||||
|
tpu::kTpuMeshCommonStateResourceName, tpu_mesh));
|
||||||
|
|
||||||
Tensor* ctx_output;
|
Tensor* ctx_output;
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
|
||||||
|
@ -112,7 +114,8 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||||
|
|
||||||
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
TpuConfigurationApi_FreeCharArray(host_config_output);
|
|
||||||
|
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
|
||||||
|
|
||||||
VLOG(1) << "ConfigureDistributedTpuOp done";
|
VLOG(1) << "ConfigureDistributedTpuOp done";
|
||||||
}
|
}
|
||||||
|
@ -171,7 +174,7 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||||
OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
|
OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
|
||||||
core::ScopedUnref mesh_state_unref(mesh_state);
|
core::ScopedUnref mesh_state_unref(mesh_state);
|
||||||
|
|
||||||
WaitForDistributedTpuOp_DoWork(
|
tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
|
||||||
num_hosts, num_devices_per_host,
|
num_hosts, num_devices_per_host,
|
||||||
const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
|
const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
|
||||||
&tpu_topology_output_size, &tpu_topology_output, status);
|
&tpu_topology_output_size, &tpu_topology_output, status);
|
||||||
|
@ -183,7 +186,7 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||||
|
|
||||||
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
TpuConfigurationApi_FreeCharArray(tpu_topology_output);
|
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
|
||||||
|
|
||||||
VLOG(1) << "WaitForDistributedTpuOp done";
|
VLOG(1) << "WaitForDistributedTpuOp done";
|
||||||
}
|
}
|
||||||
|
@ -196,7 +199,7 @@ void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||||
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
|
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
|
||||||
GetTPUConfigResourceMgr(),
|
GetTPUConfigResourceMgr(),
|
||||||
tpu::kTpuMeshCommonStateResourceName));
|
tpu::kTpuMeshCommonStateResourceName));
|
||||||
ShutdownDistributedTpuOp_DoWork(status);
|
tpu::ConfigApiFn()->ShutdownDistributedTpuOp_DoWorkFn(status);
|
||||||
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
|
|
||||||
|
@ -213,7 +216,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||||
int32_t* device_id_output;
|
int32_t* device_id_output;
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
|
|
||||||
InitializeHostForDistributedTpuOp_DoWork(
|
tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
|
||||||
tpu_host_config.size(), tpu_host_config.data(),
|
tpu_host_config.size(), tpu_host_config.data(),
|
||||||
enable_whole_mesh_compilations_, &device_id_output_size,
|
enable_whole_mesh_compilations_, &device_id_output_size,
|
||||||
&device_id_output, status);
|
&device_id_output, status);
|
||||||
|
@ -230,7 +233,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
|
||||||
|
|
||||||
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
TpuConfigurationApi_FreeInt32Array(device_id_output);
|
tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
|
||||||
|
|
||||||
VLOG(1) << "InitializeHostForDistributedTpuOp done";
|
VLOG(1) << "InitializeHostForDistributedTpuOp done";
|
||||||
}
|
}
|
||||||
|
@ -242,7 +245,8 @@ void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
|
||||||
auto tpu_topology = ctx->input(0).scalar<tstring>()();
|
auto tpu_topology = ctx->input(0).scalar<tstring>()();
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
|
|
||||||
SetGlobalTPUArrayOp_DoWork(tpu_topology.size(), tpu_topology.data(), status);
|
tpu::ConfigApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
|
||||||
|
tpu_topology.data(), status);
|
||||||
|
|
||||||
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
|
||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
|
@ -257,7 +261,8 @@ void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
int32_t number_of_chips_output = 0;
|
int32_t number_of_chips_output = 0;
|
||||||
|
|
||||||
DisconnectDistributedTpuChipsOp_DoWork(&number_of_chips_output, status);
|
tpu::ConfigApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
|
||||||
|
&number_of_chips_output, status);
|
||||||
|
|
||||||
Tensor* ctx_output;
|
Tensor* ctx_output;
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
|
||||||
|
|
|
@ -15,20 +15,29 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
|
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
|
||||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
|
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
|
|
||||||
typedef struct XLA_TpuMeshState XLA_TpuMeshState;
|
typedef struct XLA_TpuMeshState XLA_TpuMeshState;
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
// Creates a new TPU mesh state object.
|
// Creates a new TPU mesh state object.
|
||||||
XLA_TpuMeshState* TpuMeshState_Create();
|
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
|
||||||
|
|
||||||
// Deletes the given TPU `mesh_state` object. Once deleted the object is
|
// Deletes the given TPU `mesh_state` object. Once deleted the object is
|
||||||
// unusable.
|
// unusable.
|
||||||
void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
|
TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
|
||||||
|
|
||||||
// Returns a pointer to an opaque mesh data structure used internally.
|
// Returns a pointer to an opaque mesh data structure used internally.
|
||||||
void* TpuMeshState_MeshCommonState(XLA_TpuMeshState* mesh_state);
|
TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
|
||||||
|
XLA_TpuMeshState* mesh_state);
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|
||||||
|
struct TfTpu_MeshStateApiFn {
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
|
||||||
|
};
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
|
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
|
||||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
@ -38,19 +39,19 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {
|
||||||
|
|
||||||
~TpuMeshStateInterface() override {
|
~TpuMeshStateInterface() override {
|
||||||
if (mesh_state_ != nullptr) {
|
if (mesh_state_ != nullptr) {
|
||||||
TpuMeshState_Free(mesh_state_);
|
MeshStateApiFn()->TpuMeshState_FreeFn(mesh_state_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static TpuMeshStateInterface* Create() {
|
static TpuMeshStateInterface* Create() {
|
||||||
return new TpuMeshStateInterface(TpuMeshState_Create());
|
return new TpuMeshStateInterface(MeshStateApiFn()->TpuMeshState_CreateFn());
|
||||||
}
|
}
|
||||||
|
|
||||||
const XLA_TpuMeshState* data() const { return mesh_state_; }
|
const XLA_TpuMeshState* data() const { return mesh_state_; }
|
||||||
|
|
||||||
tensorflow::TpuMeshCommonState* mesh_common_state() const {
|
tensorflow::TpuMeshCommonState* mesh_common_state() const {
|
||||||
return static_cast<tensorflow::TpuMeshCommonState*>(
|
return static_cast<tensorflow::TpuMeshCommonState*>(
|
||||||
TpuMeshState_MeshCommonState(mesh_state_));
|
MeshStateApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether we should include the device assignment as a static field
|
// Returns whether we should include the device assignment as a static field
|
||||||
|
@ -62,8 +63,8 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {
|
||||||
// Static device assignment enables XLA to perform certain optimization when
|
// Static device assignment enables XLA to perform certain optimization when
|
||||||
// all cores are used in the replicated computation.
|
// all cores are used in the replicated computation.
|
||||||
return metadata.num_cores_per_replica() * metadata.num_replicas() ==
|
return metadata.num_cores_per_replica() * metadata.num_replicas() ==
|
||||||
TpuTopology_AvailableCoreCount(mesh_state_,
|
CompileApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_,
|
||||||
tpu_core_type);
|
tpu_core_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
string DebugString() const override { return "TpuMeshStateInterface"; }
|
string DebugString() const override { return "TpuMeshStateInterface"; }
|
||||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
|
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
|
||||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
|
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
|
|
||||||
typedef struct SE_Status SE_Status;
|
typedef struct SE_Status SE_Status;
|
||||||
|
@ -32,4 +33,9 @@ void TpuCompile_ToTpuShapeRepresentation(
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|
||||||
|
struct TfTpu_UtilApiFn {
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ToTpuShapeRepresentation);
|
||||||
|
};
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
|
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
|
||||||
|
|
|
@ -0,0 +1,166 @@
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
|
||||||
|
auto* config_fn = tensorflow::tpu::ConfigApiFn();
|
||||||
|
|
||||||
|
TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
|
||||||
|
TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
|
||||||
|
TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
|
||||||
|
TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
|
||||||
|
TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
|
||||||
|
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
|
||||||
|
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
|
||||||
|
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Status SetTpuMeshStateStructFns(void* library_handle) {
|
||||||
|
auto* mesh_state_fn = tensorflow::tpu::MeshStateApiFn();
|
||||||
|
|
||||||
|
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Create);
|
||||||
|
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Free);
|
||||||
|
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_MeshCommonState);
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Status SetCompileStructFn(void* library_handle) {
|
||||||
|
auto* compile_fn = tensorflow::tpu::CompileApiFn();
|
||||||
|
|
||||||
|
TFTPU_SET_FN(compile_fn, TpuTopology_AvailableCoreCount);
|
||||||
|
TFTPU_SET_FN(compile_fn, TpuCompile_CreateCompilationCacheKey);
|
||||||
|
TFTPU_SET_FN(compile_fn, TpuCompile_CreateGuaranteedConstFingerprint);
|
||||||
|
TFTPU_SET_FN(compile_fn, TpuCompile_CompileAheadOfTime);
|
||||||
|
TFTPU_SET_FN(compile_fn, TpuCompile_BuildXLADeviceAssignment);
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||||
|
auto* executor_fn = tensorflow::tpu::ExecutorApiFn();
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_New);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_Free);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_Initialize);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_Initialized);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_GetExecutor);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_Id);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_VisibleDeviceCount);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_TpuMemoryLimit);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_Init);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_Free);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_PlatformDeviceCount);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_Allocate);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_Deallocate);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_GetAllocatorStats);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_DeviceMemoryUsage);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateStream);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateStream);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_CreateStreamDependency);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_GetStatus);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateEvent);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateEvent);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_PollForEventStatus);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_RecordEvent);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForEvent);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateTimer);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateTimer);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_StartTimer);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_StopTimer);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronousMemcpyToHost);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronousMemcpyFromHost);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_MemcpyToHost);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_MemcpyFromHost);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_EnqueueInfeed);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_DequeueOutfeed);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForInfeedReady);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForOutfeedReady);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_BlockHostUntilDone);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_BlockUntilDoneOrFailed);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_SyncAndForgetFailedStreams);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronizeAllActivity);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStream_New);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStream_Free);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStream_Stream);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStream_Status);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuEvent_New);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuEvent_Free);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTimer_New);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTimer_Free);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTimer_Nanoseconds);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTimer_Microseconds);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStatus_New);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStatus_Create);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStatus_Free);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStatus_Message);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStatus_Code);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStatus_Ok);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Default);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_SetOrdinal);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Free);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuDeviceDescription_New);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuDeviceDescription_Free);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_CreateDeviceDescription);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_NewDeviceOptions);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_FreeDeviceOptions);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutor_HostCallback);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_New);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_Free);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_PlatformId);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_HostShapeToDeviceShape);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralToDeviceAsync);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralFromDevice);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
|
||||||
|
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_New);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_Free);
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Status SetTpuNodeContextStructFns(void* library_handle) {
|
||||||
|
auto* node_context_fn = tensorflow::tpu::NodeContextApiFn();
|
||||||
|
|
||||||
|
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Create);
|
||||||
|
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Free);
|
||||||
|
TFTPU_SET_FN(node_context_fn, TpuNodeContext_StopChipHeartbeats);
|
||||||
|
TFTPU_SET_FN(node_context_fn, TpuNodeContext_CloseTpuHost);
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Status SetTpuUtilStructFns(void* library_handle) {
|
||||||
|
auto* util_fn = tensorflow::tpu::UtilApiFn();
|
||||||
|
|
||||||
|
TFTPU_SET_FN(util_fn, TpuCompile_IsTpuCompilationEnabled);
|
||||||
|
TFTPU_SET_FN(util_fn, TpuCompile_ToTpuShapeRepresentation);
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Status InitializeTpuStructFns(void* library_handle) {
|
||||||
|
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
|
||||||
|
TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle));
|
||||||
|
TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle));
|
||||||
|
TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
|
||||||
|
TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle));
|
||||||
|
TF_RETURN_IF_ERROR(SetTpuUtilStructFns(library_handle));
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
|
@ -13,16 +13,23 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
// TODO(frankchn): Rename to `tpu_api_dlsym_initializer` or similar.
|
||||||
|
|
||||||
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
|
|
||||||
#include <dlfcn.h>
|
#include <dlfcn.h>
|
||||||
|
|
||||||
#define TFTPU_SET_FN(Struct, FnName) \
|
|
||||||
Struct->FnName##Fn = \
|
|
||||||
reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName));
|
|
||||||
|
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||||
|
|
||||||
|
#define TFTPU_SET_FN(Struct, FnName) \
|
||||||
|
Struct->FnName##Fn = \
|
||||||
|
reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName)); \
|
||||||
|
if (!(Struct->FnName##Fn)) { \
|
||||||
|
LOG(ERROR) << #FnName " not available in this library."; \
|
||||||
|
}
|
||||||
|
|
||||||
// Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly
|
// Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly
|
||||||
// visible methods.
|
// visible methods.
|
||||||
|
@ -30,28 +37,7 @@ limitations under the License.
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
|
||||||
Status SetTpuInitializeStructFns(void* library_handle) {
|
#include "tensorflow/core/tpu/tpu_library_init_fns.inc"
|
||||||
auto* base_fn = InitializeApiFn();
|
|
||||||
|
|
||||||
TFTPU_SET_FN(base_fn, TfTpu_Initialize);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status SetTpuConfigStructFns(void* library_handle) {
|
|
||||||
auto* config_fn = ConfigApiFn();
|
|
||||||
|
|
||||||
TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
|
|
||||||
TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
|
|
||||||
TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
|
|
||||||
TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
|
|
||||||
TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
|
|
||||||
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
|
|
||||||
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
|
|
||||||
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
TfTpu_BaseFn* InitializeApiFn() {
|
TfTpu_BaseFn* InitializeApiFn() {
|
||||||
static TfTpu_BaseFn base_fn;
|
static TfTpu_BaseFn base_fn;
|
||||||
|
@ -63,19 +49,48 @@ TfTpu_ConfigApiFn* ConfigApiFn() {
|
||||||
return &config_api_fn;
|
return &config_api_fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfTpu_MeshStateApiFn* MeshStateApiFn() {
|
||||||
|
static TfTpu_MeshStateApiFn mesh_state_api_fn;
|
||||||
|
return &mesh_state_api_fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfTpu_CompileApiFn* CompileApiFn() {
|
||||||
|
static TfTpu_CompileApiFn compile_api_fn;
|
||||||
|
return &compile_api_fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfTpu_ExecutorApiFn* ExecutorApiFn() {
|
||||||
|
static TfTpu_ExecutorApiFn executor_api_fn;
|
||||||
|
return &executor_api_fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfTpu_NodeContextApiFn* NodeContextApiFn() {
|
||||||
|
static TfTpu_NodeContextApiFn node_context_api_fn;
|
||||||
|
return &node_context_api_fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfTpu_UtilApiFn* UtilApiFn() {
|
||||||
|
static TfTpu_UtilApiFn util_api_fn;
|
||||||
|
return &util_api_fn;
|
||||||
|
}
|
||||||
|
|
||||||
Status InitializeTpuLibrary(void* library_handle) {
|
Status InitializeTpuLibrary(void* library_handle) {
|
||||||
bool shared_object_loaded = true;
|
bool shared_object_loaded = true;
|
||||||
if (library_handle == nullptr) {
|
if (library_handle == nullptr) {
|
||||||
library_handle = dlopen(nullptr, RTLD_LAZY);
|
library_handle = dlopen(nullptr, RTLD_NOW);
|
||||||
shared_object_loaded = false;
|
shared_object_loaded = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle));
|
TF_RETURN_IF_ERROR(InitializeTpuStructFns(library_handle));
|
||||||
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
|
|
||||||
|
|
||||||
if (shared_object_loaded) {
|
if (shared_object_loaded) {
|
||||||
|
// TODO(frankchn): Make initialization actually work
|
||||||
// Initialize TPU platform when the platform code is loaded from a library.
|
// Initialize TPU platform when the platform code is loaded from a library.
|
||||||
InitializeApiFn()->TfTpu_InitializeFn();
|
// InitializeApiFn()->TfTpu_InitializeFn();
|
||||||
|
|
||||||
|
// We should only register the TPU platform when the library is loaded.
|
||||||
|
// TODO(frankchn): Resolve the circular dependency and register the platform
|
||||||
|
// RegisterTpuPlatform();
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -17,8 +17,13 @@ limitations under the License.
|
||||||
#define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_
|
#define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_
|
||||||
|
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
||||||
#include "tensorflow/core/tpu/libtftpu.h"
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
#include "tensorflow/core/tpu/tpu_config_c_api.h"
|
#include "tensorflow/core/tpu/tpu_config_c_api.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
|
||||||
|
|
||||||
// LINT.IfChange
|
// LINT.IfChange
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -26,10 +31,21 @@ namespace tpu {
|
||||||
|
|
||||||
Status InitializeTpuLibrary(void* library_handle);
|
Status InitializeTpuLibrary(void* library_handle);
|
||||||
|
|
||||||
|
// TODO(frankchn): Separate out API functions from the loader.
|
||||||
TfTpu_BaseFn* InitializeApiFn();
|
TfTpu_BaseFn* InitializeApiFn();
|
||||||
|
|
||||||
TfTpu_ConfigApiFn* ConfigApiFn();
|
TfTpu_ConfigApiFn* ConfigApiFn();
|
||||||
|
|
||||||
|
TfTpu_MeshStateApiFn* MeshStateApiFn();
|
||||||
|
|
||||||
|
TfTpu_CompileApiFn* CompileApiFn();
|
||||||
|
|
||||||
|
TfTpu_ExecutorApiFn* ExecutorApiFn();
|
||||||
|
|
||||||
|
TfTpu_NodeContextApiFn* NodeContextApiFn();
|
||||||
|
|
||||||
|
TfTpu_UtilApiFn* UtilApiFn();
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
// LINT.ThenChange(//tensorflow/core/tpu/tpu_library_loader_windows.cc)
|
// LINT.ThenChange(//tensorflow/core/tpu/tpu_library_loader_windows.cc)
|
||||||
|
|
|
@ -27,6 +27,16 @@ TfTpu_BaseFn* InitializeApiFn() { return nullptr; }
|
||||||
|
|
||||||
TfTpu_ConfigApiFn* ConfigApiFn() { return nullptr; }
|
TfTpu_ConfigApiFn* ConfigApiFn() { return nullptr; }
|
||||||
|
|
||||||
|
TfTpu_MeshStateApiFn* MeshStateApiFn() { return nullptr; }
|
||||||
|
|
||||||
|
TfTpu_CompileApiFn* CompileApiFn() { return nullptr; }
|
||||||
|
|
||||||
|
TfTpu_ExecutorApiFn* ExecutorApiFn() { return nullptr; }
|
||||||
|
|
||||||
|
TfTpu_NodeContextApiFn* NodeContextApiFn() { return nullptr; }
|
||||||
|
|
||||||
|
TfTpu_UtilApiFn* UtilApiFn() { return nullptr; }
|
||||||
|
|
||||||
Status InitializeTpuLibrary(void* library_handle) {
|
Status InitializeTpuLibrary(void* library_handle) {
|
||||||
return errors::Unimplemented(
|
return errors::Unimplemented(
|
||||||
"Loading TPU library is not supported on Windows.");
|
"Loading TPU library is not supported on Windows.");
|
||||||
|
|
|
@ -11,20 +11,25 @@ package(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_executor_c_api_hdrs",
|
name = "tpu_executor_c_api_hdrs",
|
||||||
hdrs = ["tpu_executor_c_api.h"],
|
hdrs = ["tpu_executor_c_api.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/c:tf_attrtype",
|
"//tensorflow/c:tf_attrtype",
|
||||||
"//tensorflow/c:tf_datatype",
|
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
|
"//tensorflow/core/tpu:libtftpu_header",
|
||||||
"//tensorflow/core/tpu/kernels:tpu_ops_common_c_api_hdrs",
|
"//tensorflow/core/tpu/kernels:tpu_ops_common_c_api_hdrs",
|
||||||
],
|
],
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_node_context_c_api_hdrs",
|
name = "tpu_node_context_c_api_hdrs",
|
||||||
hdrs = ["tpu_node_context_c_api.h"],
|
hdrs = ["tpu_node_context_c_api.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":tpu_executor_c_api_hdrs",
|
":tpu_executor_c_api_hdrs",
|
||||||
|
"//tensorflow/core/tpu:libtftpu_header",
|
||||||
],
|
],
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -65,6 +70,7 @@ cc_library(
|
||||||
":status_helper",
|
":status_helper",
|
||||||
":tpu_executor_c_api_hdrs",
|
":tpu_executor_c_api_hdrs",
|
||||||
":tpu_stream_interface",
|
":tpu_stream_interface",
|
||||||
|
"//tensorflow/core/tpu:tpu_library_loader",
|
||||||
"//tensorflow/stream_executor:stream",
|
"//tensorflow/stream_executor:stream",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -75,6 +81,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":tpu_executor_c_api_hdrs",
|
":tpu_executor_c_api_hdrs",
|
||||||
"//tensorflow/core/platform:types",
|
"//tensorflow/core/platform:types",
|
||||||
|
"//tensorflow/core/tpu:tpu_library_loader",
|
||||||
"//tensorflow/stream_executor:stream",
|
"//tensorflow/stream_executor:stream",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -94,6 +101,7 @@ cc_library(
|
||||||
":tpu_timer",
|
":tpu_timer",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/tpu:tpu_library_loader",
|
||||||
"//tensorflow/stream_executor:stream",
|
"//tensorflow/stream_executor:stream",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
@ -143,6 +151,7 @@ cc_library(
|
||||||
"//tensorflow/compiler/xla/service:stream_pool",
|
"//tensorflow/compiler/xla/service:stream_pool",
|
||||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core/tpu:tpu_library_loader",
|
||||||
"//tensorflow/stream_executor:device_memory_allocator",
|
"//tensorflow/stream_executor:device_memory_allocator",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
@ -160,6 +169,7 @@ cc_library(
|
||||||
":tpu_platform_interface",
|
":tpu_platform_interface",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/core/platform:types",
|
"//tensorflow/core/platform:types",
|
||||||
|
"//tensorflow/core/tpu:tpu_library_loader",
|
||||||
"//tensorflow/stream_executor:stream",
|
"//tensorflow/stream_executor:stream",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
|
@ -191,6 +201,7 @@ cc_library(
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||||
|
"//tensorflow/core/tpu:tpu_library_loader",
|
||||||
"//tensorflow/stream_executor:stream",
|
"//tensorflow/stream_executor:stream",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -217,6 +228,7 @@ cc_library(
|
||||||
"//tensorflow/core/platform:types",
|
"//tensorflow/core/platform:types",
|
||||||
"//tensorflow/stream_executor:multi_platform_manager",
|
"//tensorflow/stream_executor:multi_platform_manager",
|
||||||
"//tensorflow/stream_executor:stream_executor_headers",
|
"//tensorflow/stream_executor:stream_executor_headers",
|
||||||
|
"//tensorflow/stream_executor:stream_executor_pimpl",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
#include "tensorflow/stream_executor/device_memory.h"
|
#include "tensorflow/stream_executor/device_memory.h"
|
||||||
#include "tensorflow/stream_executor/lib/status.h"
|
#include "tensorflow/stream_executor/lib/status.h"
|
||||||
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
||||||
|
@ -33,63 +34,68 @@ namespace {
|
||||||
using ::stream_executor::port::Status;
|
using ::stream_executor::port::Status;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TpuExecutor::~TpuExecutor() { TpuExecutor_Free(executor_); }
|
TpuExecutor::~TpuExecutor() {
|
||||||
|
tpu::ExecutorApiFn()->TpuExecutor_FreeFn(executor_);
|
||||||
|
}
|
||||||
|
|
||||||
Status TpuExecutor::Init(int device_ordinal,
|
Status TpuExecutor::Init(int device_ordinal,
|
||||||
::stream_executor::DeviceOptions device_options) {
|
::stream_executor::DeviceOptions device_options) {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
SE_DeviceOptions* options =
|
SE_DeviceOptions* options =
|
||||||
TpuExecutor_NewDeviceOptions(device_options.flags());
|
tpu::ExecutorApiFn()->TpuExecutor_NewDeviceOptionsFn(
|
||||||
TpuExecutor_Init(executor_, device_ordinal, options, status.c_status);
|
device_options.flags());
|
||||||
TpuExecutor_FreeDeviceOptions(options);
|
tpu::ExecutorApiFn()->TpuExecutor_InitFn(executor_, device_ordinal, options,
|
||||||
|
status.c_status);
|
||||||
|
tpu::ExecutorApiFn()->TpuExecutor_FreeDeviceOptionsFn(options);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
int TpuExecutor::PlatformDeviceCount() {
|
int TpuExecutor::PlatformDeviceCount() {
|
||||||
return TpuExecutor_PlatformDeviceCount(executor_);
|
return tpu::ExecutorApiFn()->TpuExecutor_PlatformDeviceCountFn(executor_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TpuExecutor::SyncAndForgetFailedStreams() {
|
void TpuExecutor::SyncAndForgetFailedStreams() {
|
||||||
TpuExecutor_SyncAndForgetFailedStreams(executor_);
|
tpu::ExecutorApiFn()->TpuExecutor_SyncAndForgetFailedStreamsFn(executor_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TpuExecutor::SynchronizeAllActivity() {
|
bool TpuExecutor::SynchronizeAllActivity() {
|
||||||
return TpuExecutor_SynchronizeAllActivity(executor_);
|
return tpu::ExecutorApiFn()->TpuExecutor_SynchronizeAllActivityFn(executor_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
|
Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuExecutor_BlockHostUntilDone(
|
tpu::ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn(
|
||||||
executor_, stream_map().at(stream->implementation()), status.c_status);
|
executor_, stream_map().at(stream->implementation()), status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuExecutor::BlockUntilDoneOrFailed() {
|
Status TpuExecutor::BlockUntilDoneOrFailed() {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuExecutor_BlockUntilDoneOrFailed(executor_, status.c_status);
|
tpu::ExecutorApiFn()->TpuExecutor_BlockUntilDoneOrFailedFn(executor_,
|
||||||
|
status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuExecutor::GetStatus(Stream* stream) {
|
Status TpuExecutor::GetStatus(Stream* stream) {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuExecutor_GetStatus(executor_, stream_map().at(stream->implementation()),
|
tpu::ExecutorApiFn()->TpuExecutor_GetStatusFn(
|
||||||
status.c_status);
|
executor_, stream_map().at(stream->implementation()), status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TpuExecutor::AllocateStream(Stream* stream) {
|
bool TpuExecutor::AllocateStream(Stream* stream) {
|
||||||
return TpuExecutor_AllocateStream(executor_,
|
return tpu::ExecutorApiFn()->TpuExecutor_AllocateStreamFn(
|
||||||
stream_map().at(stream->implementation()));
|
executor_, stream_map().at(stream->implementation()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void TpuExecutor::DeallocateStream(Stream* stream) {
|
void TpuExecutor::DeallocateStream(Stream* stream) {
|
||||||
TpuExecutor_DeallocateStream(executor_,
|
tpu::ExecutorApiFn()->TpuExecutor_DeallocateStreamFn(
|
||||||
stream_map().at(stream->implementation()));
|
executor_, stream_map().at(stream->implementation()));
|
||||||
stream_map().erase(stream->implementation());
|
stream_map().erase(stream->implementation());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
|
bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
|
||||||
return TpuExecutor_CreateStreamDependency(
|
return tpu::ExecutorApiFn()->TpuExecutor_CreateStreamDependencyFn(
|
||||||
executor_, stream_map().at(dependent->implementation()),
|
executor_, stream_map().at(dependent->implementation()),
|
||||||
stream_map().at(other->implementation()));
|
stream_map().at(other->implementation()));
|
||||||
}
|
}
|
||||||
|
@ -104,15 +110,15 @@ bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
|
||||||
void TpuExecutor::DeallocateTimer(Timer* timer) {}
|
void TpuExecutor::DeallocateTimer(Timer* timer) {}
|
||||||
|
|
||||||
bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
|
bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
|
||||||
return TpuExecutor_StartTimer(executor_,
|
return tpu::ExecutorApiFn()->TpuExecutor_StartTimerFn(
|
||||||
stream_map().at(stream->implementation()),
|
executor_, stream_map().at(stream->implementation()),
|
||||||
timer_map_.at(timer->implementation()));
|
timer_map_.at(timer->implementation()));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
|
bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
|
||||||
return TpuExecutor_StopTimer(executor_,
|
return tpu::ExecutorApiFn()->TpuExecutor_StopTimerFn(
|
||||||
stream_map().at(stream->implementation()),
|
executor_, stream_map().at(stream->implementation()),
|
||||||
timer_map_.at(timer->implementation()));
|
timer_map_.at(timer->implementation()));
|
||||||
}
|
}
|
||||||
|
|
||||||
stream_executor::Event::Status TpuExecutor::PollForEventStatus(
|
stream_executor::Event::Status TpuExecutor::PollForEventStatus(
|
||||||
|
@ -148,7 +154,7 @@ Status TpuExecutor::WaitForEvent(Stream* stream,
|
||||||
// Called by Timer::Timer
|
// Called by Timer::Timer
|
||||||
std::unique_ptr<::stream_executor::internal::TimerInterface>
|
std::unique_ptr<::stream_executor::internal::TimerInterface>
|
||||||
TpuExecutor::GetTimerImplementation() {
|
TpuExecutor::GetTimerImplementation() {
|
||||||
SE_Timer* tpu_timer = TpuTimer_New(executor_);
|
SE_Timer* tpu_timer = tpu::ExecutorApiFn()->TpuTimer_NewFn(executor_);
|
||||||
auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
|
auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
|
||||||
timer_map_[ptr.get()] = tpu_timer;
|
timer_map_[ptr.get()] = tpu_timer;
|
||||||
return ptr;
|
return ptr;
|
||||||
|
@ -157,7 +163,7 @@ TpuExecutor::GetTimerImplementation() {
|
||||||
// Called by Stream::Stream
|
// Called by Stream::Stream
|
||||||
std::unique_ptr<::stream_executor::internal::StreamInterface>
|
std::unique_ptr<::stream_executor::internal::StreamInterface>
|
||||||
TpuExecutor::GetStreamImplementation() {
|
TpuExecutor::GetStreamImplementation() {
|
||||||
SE_Stream* tpu_stream = TpuStream_New(executor_);
|
SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
|
||||||
auto ptr = absl::make_unique<TpuStream>(tpu_stream);
|
auto ptr = absl::make_unique<TpuStream>(tpu_stream);
|
||||||
stream_map()[ptr.get()] = tpu_stream;
|
stream_map()[ptr.get()] = tpu_stream;
|
||||||
return ptr;
|
return ptr;
|
||||||
|
@ -166,34 +172,35 @@ TpuExecutor::GetStreamImplementation() {
|
||||||
// Called by Event::Event
|
// Called by Event::Event
|
||||||
std::unique_ptr<::stream_executor::internal::EventInterface>
|
std::unique_ptr<::stream_executor::internal::EventInterface>
|
||||||
TpuExecutor::CreateEventImplementation() {
|
TpuExecutor::CreateEventImplementation() {
|
||||||
SE_Event* tpu_event = TpuEvent_New(executor_);
|
SE_Event* tpu_event = tpu::ExecutorApiFn()->TpuEvent_NewFn(executor_);
|
||||||
auto ptr = absl::make_unique<TpuEvent>(tpu_event);
|
auto ptr = absl::make_unique<TpuEvent>(tpu_event);
|
||||||
event_map()[ptr.get()] = tpu_event;
|
event_map()[ptr.get()] = tpu_event;
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
|
DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
|
||||||
SE_DeviceMemoryBase se_base =
|
SE_DeviceMemoryBase se_base = tpu::ExecutorApiFn()->TpuExecutor_AllocateFn(
|
||||||
TpuExecutor_Allocate(executor_, size, memory_space);
|
executor_, size, memory_space);
|
||||||
return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
|
return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
|
void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
|
||||||
SE_DeviceMemoryBase se_base =
|
SE_DeviceMemoryBase se_base =
|
||||||
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
|
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
|
||||||
TpuExecutor_Deallocate(executor_, &se_base);
|
tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
|
void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
|
||||||
SE_DeviceMemoryBase se_base =
|
SE_DeviceMemoryBase se_base =
|
||||||
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
|
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
|
||||||
TpuExecutor_Deallocate(executor_, &se_base);
|
tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
|
bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
|
||||||
int64_t _free;
|
int64_t _free;
|
||||||
int64_t _total;
|
int64_t _total;
|
||||||
if (TpuExecutor_DeviceMemoryUsage(executor_, &_free, &_total)) {
|
if (tpu::ExecutorApiFn()->TpuExecutor_DeviceMemoryUsageFn(executor_, &_free,
|
||||||
|
&_total)) {
|
||||||
*free = _free;
|
*free = _free;
|
||||||
*total = _total;
|
*total = _total;
|
||||||
return true;
|
return true;
|
||||||
|
@ -204,7 +211,8 @@ bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
|
||||||
absl::optional<stream_executor::AllocatorStats>
|
absl::optional<stream_executor::AllocatorStats>
|
||||||
TpuExecutor::GetAllocatorStats() {
|
TpuExecutor::GetAllocatorStats() {
|
||||||
SE_AllocatorStats c_stats;
|
SE_AllocatorStats c_stats;
|
||||||
if (TpuExecutor_GetAllocatorStats(executor_, &c_stats)) {
|
if (tpu::ExecutorApiFn()->TpuExecutor_GetAllocatorStatsFn(executor_,
|
||||||
|
&c_stats)) {
|
||||||
::stream_executor::AllocatorStats stats;
|
::stream_executor::AllocatorStats stats;
|
||||||
stats.num_allocs = c_stats.num_allocs;
|
stats.num_allocs = c_stats.num_allocs;
|
||||||
stats.bytes_in_use = c_stats.bytes_in_use;
|
stats.bytes_in_use = c_stats.bytes_in_use;
|
||||||
|
@ -226,31 +234,33 @@ TpuExecutor::GetAllocatorStats() {
|
||||||
|
|
||||||
Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
|
Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuExecutor_WaitForInfeedReady(executor_, infeed_queue_index,
|
tpu::ExecutorApiFn()->TpuExecutor_WaitForInfeedReadyFn(
|
||||||
status.c_status);
|
executor_, infeed_queue_index, status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
|
Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuExecutor_WaitForOutfeedReady(executor_, outfeed_queue_index,
|
tpu::ExecutorApiFn()->TpuExecutor_WaitForOutfeedReadyFn(
|
||||||
status.c_status);
|
executor_, outfeed_queue_index, status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
|
void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
|
||||||
absl::Span<uint8> bytes, StatusCallback done) {
|
absl::Span<uint8> bytes, StatusCallback done) {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuExecutor_DequeueOutfeed(executor_, outfeed_queue_index, bytes.data(),
|
tpu::ExecutorApiFn()->TpuExecutor_DequeueOutfeedFn(
|
||||||
bytes.size(), status.c_status);
|
executor_, outfeed_queue_index, bytes.data(), bytes.size(),
|
||||||
|
status.c_status);
|
||||||
done(status.status());
|
done(status.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
|
Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
|
||||||
absl::Span<const uint8> bytes) {
|
absl::Span<const uint8> bytes) {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuExecutor_EnqueueInfeed(executor_, infeed_queue_index, bytes.data(),
|
tpu::ExecutorApiFn()->TpuExecutor_EnqueueInfeedFn(
|
||||||
bytes.size(), status.c_status);
|
executor_, infeed_queue_index, bytes.data(), bytes.size(),
|
||||||
|
status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -259,9 +269,9 @@ bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
|
||||||
uint64 size) {
|
uint64 size) {
|
||||||
SE_DeviceMemoryBase se_base =
|
SE_DeviceMemoryBase se_base =
|
||||||
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
|
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
|
||||||
return TpuExecutor_MemcpyToHost(executor_,
|
return tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
|
||||||
stream_map().at(stream->implementation()),
|
executor_, stream_map().at(stream->implementation()), host_dst, &se_base,
|
||||||
host_dst, &se_base, size);
|
size);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TpuExecutor::Memcpy(Stream* stream,
|
bool TpuExecutor::Memcpy(Stream* stream,
|
||||||
|
@ -269,9 +279,9 @@ bool TpuExecutor::Memcpy(Stream* stream,
|
||||||
const void* host_src, uint64 size) {
|
const void* host_src, uint64 size) {
|
||||||
SE_DeviceMemoryBase se_base =
|
SE_DeviceMemoryBase se_base =
|
||||||
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
|
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
|
||||||
return TpuExecutor_MemcpyFromHost(executor_,
|
return tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
|
||||||
stream_map().at(stream->implementation()),
|
executor_, stream_map().at(stream->implementation()), &se_base, host_src,
|
||||||
&se_base, host_src, size);
|
size);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TpuExecutor::SynchronousMemcpy(
|
Status TpuExecutor::SynchronousMemcpy(
|
||||||
|
@ -280,8 +290,8 @@ Status TpuExecutor::SynchronousMemcpy(
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
SE_DeviceMemoryBase se_base =
|
SE_DeviceMemoryBase se_base =
|
||||||
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
|
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
|
||||||
TpuExecutor_SynchronousMemcpyFromHost(executor_, &se_base, host_src, size,
|
tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyFromHostFn(
|
||||||
status.c_status);
|
executor_, &se_base, host_src, size, status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,8 +301,8 @@ Status TpuExecutor::SynchronousMemcpy(
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
SE_DeviceMemoryBase se_base =
|
SE_DeviceMemoryBase se_base =
|
||||||
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
|
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
|
||||||
TpuExecutor_SynchronousMemcpyToHost(executor_, host_dst, &se_base, size,
|
tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyToHostFn(
|
||||||
status.c_status);
|
executor_, host_dst, &se_base, size, status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -316,8 +326,8 @@ struct HostCallbackContext {
|
||||||
SE_Status* HostCallbackTrampoline(void* ctx) {
|
SE_Status* HostCallbackTrampoline(void* ctx) {
|
||||||
HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
|
HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
|
||||||
Status status = host_ctx->callback();
|
Status status = host_ctx->callback();
|
||||||
SE_Status* c_status =
|
SE_Status* c_status = tpu::ExecutorApiFn()->TpuStatus_CreateFn(
|
||||||
TpuStatus_Create(status.code(), status.error_message().c_str());
|
status.code(), status.error_message().c_str());
|
||||||
delete host_ctx;
|
delete host_ctx;
|
||||||
return c_status;
|
return c_status;
|
||||||
}
|
}
|
||||||
|
@ -325,18 +335,21 @@ SE_Status* HostCallbackTrampoline(void* ctx) {
|
||||||
bool TpuExecutor::HostCallback(Stream* stream,
|
bool TpuExecutor::HostCallback(Stream* stream,
|
||||||
std::function<Status()> callback) {
|
std::function<Status()> callback) {
|
||||||
HostCallbackContext* ctx = new HostCallbackContext{callback};
|
HostCallbackContext* ctx = new HostCallbackContext{callback};
|
||||||
return TpuExecutor_HostCallback(executor_,
|
return tpu::ExecutorApiFn()->TpuExecutor_HostCallbackFn(
|
||||||
stream_map().at(stream->implementation()),
|
executor_, stream_map().at(stream->implementation()),
|
||||||
&HostCallbackTrampoline, ctx);
|
&HostCallbackTrampoline, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
|
TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
|
||||||
TpuExecutor::CreateDeviceDescription() const {
|
TpuExecutor::CreateDeviceDescription() const {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
SE_DeviceDescription* description = TpuDeviceDescription_New();
|
SE_DeviceDescription* description =
|
||||||
auto cleanup = tensorflow::gtl::MakeCleanup(
|
tpu::ExecutorApiFn()->TpuDeviceDescription_NewFn();
|
||||||
[description]() { TpuDeviceDescription_Free(description); });
|
auto cleanup = tensorflow::gtl::MakeCleanup([description]() {
|
||||||
TpuExecutor_CreateDeviceDescription(executor_, description, status.c_status);
|
tpu::ExecutorApiFn()->TpuDeviceDescription_FreeFn(description);
|
||||||
|
});
|
||||||
|
tpu::ExecutorApiFn()->TpuExecutor_CreateDeviceDescriptionFn(
|
||||||
|
executor_, description, status.c_status);
|
||||||
if (status.status().ok()) {
|
if (status.status().ok()) {
|
||||||
stream_executor::internal::DeviceDescriptionBuilder builder;
|
stream_executor::internal::DeviceDescriptionBuilder builder;
|
||||||
CHECK_NE(description->device_vendor, nullptr);
|
CHECK_NE(description->device_vendor, nullptr);
|
||||||
|
|
|
@ -20,9 +20,9 @@ limitations under the License.
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include "tensorflow/c/tf_attrtype.h"
|
#include "tensorflow/c/tf_attrtype.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
|
|
||||||
typedef struct SE_Platform SE_Platform;
|
typedef struct SE_Platform SE_Platform;
|
||||||
typedef struct SE_StreamExecutor SE_StreamExecutor;
|
typedef struct SE_StreamExecutor SE_StreamExecutor;
|
||||||
|
@ -292,6 +292,96 @@ void TpuTransferManager_WriteSingleTupleIndexTable(
|
||||||
|
|
||||||
XLA_ComputationPlacer* TpuComputationPlacer_New();
|
XLA_ComputationPlacer* TpuComputationPlacer_New();
|
||||||
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
|
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
|
||||||
|
|
||||||
|
struct TfTpu_ExecutorApiFn {
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_New);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Free);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Initialize);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Initialized);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetExecutor);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Id);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_VisibleDeviceCount);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_TpuMemoryLimit);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Init);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Free);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_PlatformDeviceCount);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Allocate);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Deallocate);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_GetAllocatorStats);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeviceMemoryUsage);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateStream);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateStream);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_CreateStreamDependency);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_GetStatus);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateEvent);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateEvent);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_PollForEventStatus);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_RecordEvent);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForEvent);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateTimer);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateTimer);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_StartTimer);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_StopTimer);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronousMemcpyToHost);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronousMemcpyFromHost);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_MemcpyToHost);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_MemcpyFromHost);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_EnqueueInfeed);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DequeueOutfeed);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForInfeedReady);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForOutfeedReady);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_BlockHostUntilDone);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_BlockUntilDoneOrFailed);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SyncAndForgetFailedStreams);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronizeAllActivity);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_New);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Free);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuEvent_Free);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTimer_New);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Free);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Nanoseconds);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Microseconds);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_New);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Create);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Free);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Message);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Code);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Ok);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Default);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_SetOrdinal);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Free);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_New);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_Free);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_CreateDeviceDescription);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_NewDeviceOptions);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_FreeDeviceOptions);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_HostCallback);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_New);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_Free);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_PlatformId);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_HostShapeToDeviceShape);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralToDeviceAsync);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralFromDevice);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
|
||||||
|
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_New);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_Free);
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// extern "C"
|
// extern "C"
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/xla/service/backend.h"
|
#include "tensorflow/compiler/xla/service/backend.h"
|
||||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
|
||||||
|
@ -32,15 +33,18 @@ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
|
||||||
int device_ordinal) {
|
int device_ordinal) {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
XLA_TpuNodeContext* node_context =
|
XLA_TpuNodeContext* node_context =
|
||||||
TpuNodeContext_Create(device_ordinal, status.c_status);
|
tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(device_ordinal,
|
||||||
|
status.c_status);
|
||||||
if (!status.status().ok()) {
|
if (!status.status().ok()) {
|
||||||
TpuNodeContext_Free(node_context);
|
tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
|
return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }
|
TpuNodeContext::~TpuNodeContext() {
|
||||||
|
tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context_);
|
||||||
|
}
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
Status TpuNodeContext::Initialize(int device_ordinal) {
|
Status TpuNodeContext::Initialize(int device_ordinal) {
|
||||||
|
@ -52,14 +56,14 @@ Status TpuNodeContext::Initialize(int device_ordinal) {
|
||||||
/* static */
|
/* static */
|
||||||
Status TpuNodeContext::StopChipHeartbeats() {
|
Status TpuNodeContext::StopChipHeartbeats() {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuNodeContext_StopChipHeartbeats(status.c_status);
|
tpu::NodeContextApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
Status TpuNodeContext::CloseTpuHost() {
|
Status TpuNodeContext::CloseTpuHost() {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuNodeContext_CloseTpuHost(status.c_status);
|
tpu::NodeContextApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,10 +15,13 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
|
||||||
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
|
|
||||||
typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
|
typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
|
XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
|
||||||
SE_Status* status);
|
SE_Status* status);
|
||||||
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
|
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
|
||||||
|
@ -28,4 +31,13 @@ void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
|
||||||
void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
|
void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
|
||||||
void TpuNodeContext_CloseTpuHost(SE_Status* status);
|
void TpuNodeContext_CloseTpuHost(SE_Status* status);
|
||||||
|
|
||||||
|
} // extern "C"
|
||||||
|
|
||||||
|
struct TfTpu_NodeContextApiFn {
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
|
||||||
|
};
|
||||||
|
|
||||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
|
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||||
|
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
#include "tensorflow/stream_executor/platform.h"
|
#include "tensorflow/stream_executor/platform.h"
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
||||||
|
@ -30,7 +31,9 @@ using Status = ::stream_executor::port::Status;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using StatusOr = ::stream_executor::port::StatusOr<T>;
|
using StatusOr = ::stream_executor::port::StatusOr<T>;
|
||||||
|
|
||||||
TpuPlatform::TpuPlatform() { platform_ = TpuPlatform_New(); }
|
TpuPlatform::TpuPlatform() {
|
||||||
|
platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
|
||||||
|
}
|
||||||
|
|
||||||
TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
|
TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
|
||||||
return tpu_registered_platform;
|
return tpu_registered_platform;
|
||||||
|
@ -53,8 +56,8 @@ Status TpuPlatform::Initialize(
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
TpuPlatform_Initialize(platform_, options_size, options_key, options_value,
|
tpu::ExecutorApiFn()->TpuPlatform_InitializeFn(
|
||||||
status.c_status);
|
platform_, options_size, options_key, options_value, status.c_status);
|
||||||
|
|
||||||
free(options_key);
|
free(options_key);
|
||||||
free(options_value);
|
free(options_value);
|
||||||
|
@ -62,10 +65,16 @@ Status TpuPlatform::Initialize(
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
TpuPlatform::~TpuPlatform() { TpuPlatform_Free(platform_); }
|
bool TpuPlatform::Initialized() const {
|
||||||
|
return tpu::ExecutorApiFn()->TpuPlatform_InitializedFn(platform_);
|
||||||
|
}
|
||||||
|
|
||||||
|
TpuPlatform::~TpuPlatform() {
|
||||||
|
tpu::ExecutorApiFn()->TpuPlatform_FreeFn(platform_);
|
||||||
|
}
|
||||||
|
|
||||||
int TpuPlatform::VisibleDeviceCount() const {
|
int TpuPlatform::VisibleDeviceCount() const {
|
||||||
return TpuPlatform_VisibleDeviceCount(platform_);
|
return tpu::ExecutorApiFn()->TpuPlatform_VisibleDeviceCountFn(platform_);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
|
StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
|
||||||
|
@ -77,14 +86,16 @@ StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
|
||||||
StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
|
StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
|
||||||
TpuPlatform::GetUncachedExecutor(
|
TpuPlatform::GetUncachedExecutor(
|
||||||
const ::stream_executor::StreamExecutorConfig& config) {
|
const ::stream_executor::StreamExecutorConfig& config) {
|
||||||
SE_StreamExecutorConfig* c_config = TpuStreamExecutorConfig_Default();
|
SE_StreamExecutorConfig* c_config =
|
||||||
|
tpu::ExecutorApiFn()->TpuStreamExecutorConfig_DefaultFn();
|
||||||
|
|
||||||
TpuStreamExecutorConfig_SetOrdinal(c_config, config.ordinal);
|
tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn(c_config,
|
||||||
|
config.ordinal);
|
||||||
|
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
SE_StreamExecutor* executor =
|
SE_StreamExecutor* executor = tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn(
|
||||||
TpuPlatform_GetExecutor(platform_, c_config, status.c_status);
|
platform_, c_config, status.c_status);
|
||||||
TpuStreamExecutorConfig_Free(c_config);
|
tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn(c_config);
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
@ -103,27 +114,24 @@ const std::string& TpuPlatform::Name() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 TpuPlatform::TpuMemoryLimit() {
|
int64 TpuPlatform::TpuMemoryLimit() {
|
||||||
return TpuPlatform_TpuMemoryLimit(platform_);
|
return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
|
bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
|
||||||
return TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(platform_);
|
return tpu::ExecutorApiFn()
|
||||||
|
->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegisterTpuPlatform() {
|
||||||
|
static bool tpu_platform_registered = false;
|
||||||
|
if (!tpu_platform_registered) {
|
||||||
|
tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
|
||||||
|
std::unique_ptr<stream_executor::Platform> platform(
|
||||||
|
tensorflow::tpu_registered_platform);
|
||||||
|
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
|
||||||
|
std::move(platform)));
|
||||||
|
tpu_platform_registered = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
void RegisterTpuPlatform() {
|
|
||||||
tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
|
|
||||||
std::unique_ptr<stream_executor::Platform> platform(
|
|
||||||
tensorflow::tpu_registered_platform);
|
|
||||||
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
|
|
||||||
std::move(platform)));
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());
|
|
||||||
|
|
||||||
// Note that module initialization sequencing is not supported in the
|
|
||||||
// open-source project, so this will be a no-op there.
|
|
||||||
REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
|
|
||||||
REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
|
|
||||||
tpu_platform);
|
|
||||||
|
|
|
@ -60,9 +60,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
|
||||||
|
|
||||||
bool ShouldRegisterTpuDeviceToDeviceCopy() override;
|
bool ShouldRegisterTpuDeviceToDeviceCopy() override;
|
||||||
|
|
||||||
bool Initialized() const override {
|
bool Initialized() const override;
|
||||||
return TpuPlatform_Initialized(platform_);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Initialize(
|
Status Initialize(
|
||||||
const std::map<std::string, std::string>& platform_options) override;
|
const std::map<std::string, std::string>& platform_options) override;
|
||||||
|
@ -124,6 +122,8 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
|
||||||
EventMap event_map_;
|
EventMap event_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void RegisterTpuPlatform();
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
|
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
|
||||||
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||||
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
|
@ -27,23 +28,27 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
||||||
using Status = stream_executor::port::Status;
|
using Status = stream_executor::port::Status;
|
||||||
|
|
||||||
explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
|
explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
|
||||||
~TpuStream() override { TpuStream_Free(stream_); }
|
~TpuStream() override {
|
||||||
|
tensorflow::tpu::ExecutorApiFn()->TpuStream_FreeFn(stream_);
|
||||||
|
}
|
||||||
|
|
||||||
bool IsSameSharedMemoryLocation(
|
bool IsSameSharedMemoryLocation(
|
||||||
tensorflow::tpu::TpuStreamInterface* other) override {
|
tensorflow::tpu::TpuStreamInterface* other) override {
|
||||||
return TpuStream_IsSameSharedMemoryLocation(
|
return tensorflow::tpu::ExecutorApiFn()
|
||||||
stream_, static_cast<TpuStream*>(other)->stream_);
|
->TpuStream_IsSameSharedMemoryLocationFn(
|
||||||
|
stream_, static_cast<TpuStream*>(other)->stream_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status EnqueueOnTpuDeviceSendRecvLocal(
|
Status EnqueueOnTpuDeviceSendRecvLocal(
|
||||||
stream_executor::DeviceMemoryBase send_buffer,
|
stream_executor::DeviceMemoryBase send_buffer,
|
||||||
stream_executor::DeviceMemoryBase recv_buffer) override {
|
stream_executor::DeviceMemoryBase recv_buffer) override {
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
TpuStream_TpuEnqueueOnDeviceSendRecvLocal(
|
tensorflow::tpu::ExecutorApiFn()
|
||||||
stream_,
|
->TpuStream_TpuEnqueueOnDeviceSendRecvLocalFn(
|
||||||
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer),
|
stream_,
|
||||||
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer),
|
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer),
|
||||||
status.c_status);
|
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer),
|
||||||
|
status.c_status);
|
||||||
return status.status();
|
return status.status();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,7 +59,9 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
|
||||||
class TpuEvent : public ::stream_executor::internal::EventInterface {
|
class TpuEvent : public ::stream_executor::internal::EventInterface {
|
||||||
public:
|
public:
|
||||||
explicit TpuEvent(SE_Event* event) : event_(event) {}
|
explicit TpuEvent(SE_Event* event) : event_(event) {}
|
||||||
~TpuEvent() override { TpuEvent_Free(event_); }
|
~TpuEvent() override {
|
||||||
|
tensorflow::tpu::ExecutorApiFn()->TpuEvent_FreeFn(event_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SE_Event* event_;
|
SE_Event* event_;
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
|
||||||
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
|
|
||||||
|
@ -25,9 +26,15 @@ namespace tensorflow {
|
||||||
class TpuTimer : public ::stream_executor::internal::TimerInterface {
|
class TpuTimer : public ::stream_executor::internal::TimerInterface {
|
||||||
public:
|
public:
|
||||||
explicit TpuTimer(SE_Timer* timer) : timer_(timer) {}
|
explicit TpuTimer(SE_Timer* timer) : timer_(timer) {}
|
||||||
~TpuTimer() override { TpuTimer_Free(timer_); }
|
~TpuTimer() override {
|
||||||
uint64 Microseconds() const override { return TpuTimer_Microseconds(timer_); }
|
tensorflow::tpu::ExecutorApiFn()->TpuTimer_FreeFn(timer_);
|
||||||
uint64 Nanoseconds() const override { return TpuTimer_Nanoseconds(timer_); }
|
}
|
||||||
|
uint64 Microseconds() const override {
|
||||||
|
return tensorflow::tpu::ExecutorApiFn()->TpuTimer_MicrosecondsFn(timer_);
|
||||||
|
}
|
||||||
|
uint64 Nanoseconds() const override {
|
||||||
|
return tensorflow::tpu::ExecutorApiFn()->TpuTimer_NanosecondsFn(timer_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SE_Timer* timer_;
|
SE_Timer* timer_;
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||||
#include "tensorflow/stream_executor/device_memory.h"
|
#include "tensorflow/stream_executor/device_memory.h"
|
||||||
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
|
@ -29,10 +30,12 @@ namespace tensorflow {
|
||||||
using Status = stream_executor::port::Status;
|
using Status = stream_executor::port::Status;
|
||||||
|
|
||||||
TpuTransferManager::TpuTransferManager() {
|
TpuTransferManager::TpuTransferManager() {
|
||||||
manager_ = TpuTransferManager_New();
|
manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
|
||||||
}
|
}
|
||||||
|
|
||||||
TpuTransferManager::~TpuTransferManager() { TpuTransferManager_Free(manager_); }
|
TpuTransferManager::~TpuTransferManager() {
|
||||||
|
tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
|
||||||
|
}
|
||||||
|
|
||||||
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
|
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
|
||||||
return TpuPlatform::kId;
|
return TpuPlatform::kId;
|
||||||
|
@ -45,8 +48,8 @@ xla::Shape TpuTransferManager::HostShapeToDeviceShape(
|
||||||
|
|
||||||
TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
|
TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
|
||||||
|
|
||||||
TpuTransferManager_HostShapeToDeviceShape(manager_, &c_host_shape,
|
tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
|
||||||
&c_device_shape);
|
manager_, &c_host_shape, &c_device_shape);
|
||||||
xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
|
xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
|
||||||
TpuConversions::CShapeCleanup(&c_host_shape);
|
TpuConversions::CShapeCleanup(&c_host_shape);
|
||||||
TpuConversions::CShapeCleanup(&c_device_shape);
|
TpuConversions::CShapeCleanup(&c_device_shape);
|
||||||
|
@ -66,7 +69,7 @@ Status TpuTransferManager::TransferLiteralToDeviceAsync(
|
||||||
TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
|
TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
|
||||||
&c_device_buffer);
|
&c_device_buffer);
|
||||||
|
|
||||||
TpuTransferManager_TransferLiteralToDeviceAsync(
|
tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
|
||||||
manager_,
|
manager_,
|
||||||
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
|
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
|
||||||
stream->implementation()),
|
stream->implementation()),
|
||||||
|
@ -112,7 +115,7 @@ void TpuTransferManager::TransferLiteralFromDevice(
|
||||||
XLA_Literal c_literal;
|
XLA_Literal c_literal;
|
||||||
TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
|
TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
|
||||||
|
|
||||||
TpuTransferManager_TransferLiteralFromDevice(
|
tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
|
||||||
manager_,
|
manager_,
|
||||||
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
|
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
|
||||||
stream->implementation()),
|
stream->implementation()),
|
||||||
|
@ -127,7 +130,8 @@ int64 TpuTransferManager::GetByteSizeRequirement(
|
||||||
TpuConversions::XlaShapeToCShape(shape, &c_shape);
|
TpuConversions::XlaShapeToCShape(shape, &c_shape);
|
||||||
|
|
||||||
int64 size_in_bytes =
|
int64 size_in_bytes =
|
||||||
TpuTransferManager_GetByteSizeRequirement(manager_, &c_shape);
|
tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
|
||||||
|
manager_, &c_shape);
|
||||||
|
|
||||||
TpuConversions::CShapeCleanup(&c_shape);
|
TpuConversions::CShapeCleanup(&c_shape);
|
||||||
return size_in_bytes;
|
return size_in_bytes;
|
||||||
|
@ -151,7 +155,7 @@ Status TpuTransferManager::WriteSingleTupleIndexTable(
|
||||||
region->payload()};
|
region->payload()};
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
|
|
||||||
TpuTransferManager_WriteSingleTupleIndexTable(
|
tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
|
||||||
manager_,
|
manager_,
|
||||||
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
|
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
|
||||||
stream->implementation()),
|
stream->implementation()),
|
||||||
|
|
|
@ -2899,6 +2899,13 @@ def if_mlir(if_true, if_false = []):
|
||||||
"//conditions:default": if_false,
|
"//conditions:default": if_false,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def if_tpu(if_true, if_false = []):
|
||||||
|
"""Shorthand for select()ing whether to build for TPUs."""
|
||||||
|
return select({
|
||||||
|
str(Label("//tensorflow:with_tpu_support")): if_true,
|
||||||
|
"//conditions:default": if_false,
|
||||||
|
})
|
||||||
|
|
||||||
def tfcompile_target_cpu():
|
def tfcompile_target_cpu():
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue