Add TPU configuration ops to default TensorFlow build

PiperOrigin-RevId: 317133514
Change-Id: I33bc6d7fdbba5915bd0d1291d4e086139c07eb14
This commit is contained in:
Frank Chen 2020-06-18 10:39:58 -07:00 committed by TensorFlower Gardener
parent d1157c976b
commit cf00e559d7
26 changed files with 627 additions and 184 deletions

View File

@ -39,6 +39,7 @@
# #
# Feature and Third party library support options: # Feature and Third party library support options:
# xla: Build TF with XLA # xla: Build TF with XLA
# tpu: Build TF with TPU support
# using_cuda: CUDA is available to build system. # using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support. # cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm). # rocm: Build with AMD GPU support (rocm).
@ -180,6 +181,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498 # AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
build:dbg --copt -DDEBUG_BUILD build:dbg --copt -DDEBUG_BUILD
# Config to build TPU backend
build:tpu --define=with_tpu_support=true
build:tensorrt --action_env TF_NEED_TENSORRT=1 build:tensorrt --action_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain

View File

@ -467,6 +467,13 @@ config_setting(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
# This flag enables experimental TPU support
config_setting(
name = "with_tpu_support",
values = {"define": "with_tpu_support=true"},
visibility = ["//visibility:public"],
)
# Specifies via a config setting if this is a mobile build or not, makes # Specifies via a config setting if this is a mobile build or not, makes
# it easier to combine settings later. # it easier to combine settings later.
selects.config_setting_group( selects.config_setting_group(

View File

@ -72,6 +72,7 @@ load(
"if_ios", "if_ios",
"if_mobile", "if_mobile",
"if_not_windows", "if_not_windows",
"if_tpu",
"tf_android_core_proto_headers", "tf_android_core_proto_headers",
"tf_cc_test", "tf_cc_test",
"tf_cc_test_mkl", "tf_cc_test_mkl",
@ -1093,6 +1094,8 @@ cc_library(
]) + if_tensorrt([ ]) + if_tensorrt([
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels", "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
"//tensorflow/compiler/tf2tensorrt:trt_op_kernels", "//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
]) + if_tpu([
"//tensorflow/core/tpu/kernels",
]), ]),
) )

View File

@ -103,6 +103,7 @@ cc_library(
":libtftpu_header", ":libtftpu_header",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
], ],
alwayslink = True,
) )
cc_library( cc_library(
@ -116,7 +117,20 @@ cc_library(
deps = [ deps = [
":libtftpu_header", ":libtftpu_header",
":tpu_config_c_api", ":tpu_config_c_api",
":tpu_library_init_fns",
"//tensorflow/core/platform:errors", "//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status", "//tensorflow/core/platform:status",
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_platform_hdrs",
], ],
) )
cc_library(
name = "tpu_library_init_fns",
hdrs = ["tpu_library_init_fns.inc"],
visibility = ["//visibility:public"],
)

View File

@ -3,6 +3,10 @@ load(
"//tensorflow/core/platform:build_config.bzl", "//tensorflow/core/platform:build_config.bzl",
"tf_proto_library_cc", "tf_proto_library_cc",
) )
load(
"//tensorflow:tensorflow.bzl",
"tf_kernel_library",
)
package( package(
default_visibility = [ default_visibility = [
@ -12,6 +16,12 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
tf_kernel_library(
name = "kernels",
visibility = ["//visibility:public"],
deps = [":tpu_configuration_ops"],
)
cc_library( cc_library(
name = "tpu_compile_op_common", name = "tpu_compile_op_common",
srcs = ["tpu_compile_op_common.cc"], srcs = ["tpu_compile_op_common.cc"],
@ -50,7 +60,7 @@ cc_library(
hdrs = ["tpu_compile_op_options.h"], hdrs = ["tpu_compile_op_options.h"],
) )
cc_library( tf_kernel_library(
name = "tpu_configuration_ops", name = "tpu_configuration_ops",
srcs = ["tpu_configuration_ops.cc"], srcs = ["tpu_configuration_ops.cc"],
hdrs = ["tpu_configuration_ops.h"], hdrs = ["tpu_configuration_ops.h"],
@ -75,12 +85,13 @@ cc_library(
name = "tpu_compile_c_api_hdrs", name = "tpu_compile_c_api_hdrs",
hdrs = ["tpu_compile_c_api.h"], hdrs = ["tpu_compile_c_api.h"],
deps = [ deps = [
":tpu_mesh_state_c_api", ":tpu_mesh_state_c_api_hdrs",
":tpu_ops_common_c_api_hdrs", ":tpu_ops_common_c_api_hdrs",
":tpu_program_c_api_hdrs", ":tpu_program_c_api_hdrs",
"//tensorflow/c:tf_datatype", "//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:proto_helper", "//tensorflow/stream_executor/tpu:proto_helper",
], ],
alwayslink = True,
) )
tf_proto_library_cc( tf_proto_library_cc(
@ -197,8 +208,10 @@ cc_library(
) )
cc_library( cc_library(
name = "tpu_mesh_state_c_api", name = "tpu_mesh_state_c_api_hdrs",
hdrs = ["tpu_mesh_state_c_api.h"], hdrs = ["tpu_mesh_state_c_api.h"],
deps = ["//tensorflow/core/tpu:libtftpu_header"],
alwayslink = True,
) )
cc_library( cc_library(
@ -207,12 +220,11 @@ cc_library(
hdrs = ["tpu_mesh_state_interface.h"], hdrs = ["tpu_mesh_state_interface.h"],
deps = [ deps = [
":tpu_compile_c_api_hdrs", ":tpu_compile_c_api_hdrs",
":tpu_mesh_state_c_api", ":tpu_mesh_state_c_api_hdrs",
"//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core/platform:errors",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/tpu:tpu_config_c_api", "//tensorflow/core/tpu:tpu_library_loader",
], ],
) )
@ -371,13 +383,16 @@ cc_library(
name = "tpu_util_c_api_hdrs", name = "tpu_util_c_api_hdrs",
hdrs = ["tpu_util_c_api.h"], hdrs = ["tpu_util_c_api.h"],
deps = [ deps = [
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:proto_helper", "//tensorflow/stream_executor/tpu:proto_helper",
], ],
alwayslink = True,
) )
cc_library( cc_library(
name = "tpu_ops_common_c_api_hdrs", name = "tpu_ops_common_c_api_hdrs",
hdrs = ["tpu_ops_common_c_api.h"], hdrs = ["tpu_ops_common_c_api.h"],
alwayslink = True,
) )
cc_library( cc_library(
@ -387,6 +402,7 @@ cc_library(
":tpu_ops_common_c_api_hdrs", ":tpu_ops_common_c_api_hdrs",
"//tensorflow/stream_executor/tpu:proto_helper", "//tensorflow/stream_executor/tpu:proto_helper",
], ],
alwayslink = True,
) )
cc_library( cc_library(

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h" #include "tensorflow/stream_executor/tpu/proto_helper.h"
enum TpuCoreTypeEnum { enum TpuCoreTypeEnum {
@ -44,35 +45,41 @@ struct CompilationCacheKeyProperty {
extern "C" { extern "C" {
// Returns the number of available TPU core count. // Returns the number of available TPU core count.
int TpuTopology_AvailableCoreCount(const XLA_TpuMeshState* mesh_state, TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
TpuCoreTypeEnum tpu_core_type); const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);
// Creates a unique compilation cache `key` used for `put` and `get` operations. // Creates a unique compilation cache `key` used for `put` and `get` operations.
// Returned buffer is heap-allocated and must be owned. // Returned buffer is heap-allocated and must be owned.
const char* TpuCompile_CreateCompilationCacheKey( TFTPU_CAPI_EXPORT const char* TpuCompile_CreateCompilationCacheKey(
CompilationCacheKeyProperty property); CompilationCacheKeyProperty property);
// Creates a guaranteed const fingerprint. Guarantee const is normally used in // Creates a guaranteed const fingerprint. Guarantee const is normally used in
// TPU inference to avoid re-copying unchanged variables onto the TPU device. // TPU inference to avoid re-copying unchanged variables onto the TPU device.
// It promises the value is identical for every execution in the same session // It promises the value is identical for every execution in the same session
// even if the actual value changes in later executions. // even if the actual value changes in later executions.
uint64_t TpuCompile_CreateGuaranteedConstFingerprint(uint64_t fingerprint, TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
const char* data, uint64_t fingerprint, const char* data, size_t size);
size_t size);
// Executes the computations using XLA TPU compiler and returns TPU programs // Executes the computations using XLA TPU compiler and returns TPU programs
// ready for execution. // ready for execution.
void TpuCompile_CompileAheadOfTime( TFTPU_CAPI_EXPORT void TpuCompile_CompileAheadOfTime(
TpuSerializedProto aot_compilation_request, TpuSerializedProto aot_compilation_request, XLA_TpuProgram** tpu_programs[],
XLA_TpuProgram** tpu_programs[],
size_t* count, SE_Status* status); size_t* count, SE_Status* status);
// Builds `DeviceAssignment` from `TpuCompileMetadata` serialized proto. // Builds `DeviceAssignment` from `TpuCompileMetadata` serialized proto.
void TpuCompile_BuildXLADeviceAssignment( TFTPU_CAPI_EXPORT void TpuCompile_BuildXLADeviceAssignment(
TpuSerializedProto serialized_tpu_compile_metadata, TpuSerializedProto serialized_tpu_compile_metadata,
const XLA_TpuMeshState* mesh_state, const XLA_TpuMeshState* mesh_state,
TpuSerializedProto* serialized_device_assignment, SE_Status* status); TpuSerializedProto* serialized_device_assignment, SE_Status* status);
struct TfTpu_CompileApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAheadOfTime);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_BuildXLADeviceAssignment);
};
} // extern "C" } // extern "C"
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_ #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_

View File

@ -353,7 +353,7 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
return; return;
} }
LogAndExit(42); std::quick_exit(42);
} }
/* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes( /* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/tpu/tpu_config_c_api.h" #include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h" #include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_library_loader.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h" #include "tensorflow/stream_executor/tpu/proto_helper.h"
namespace tensorflow { namespace tensorflow {
@ -97,13 +98,14 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>( OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
rmgr, tpu::kTpuMeshCommonStateResourceName)); rmgr, tpu::kTpuMeshCommonStateResourceName));
ConfigureDistributedTpuOp_DoWork( tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
num_devices_per_host.size(), num_devices_per_host.data(), num_devices_per_host.size(), num_devices_per_host.data(),
&host_config_output_size, &host_config_output, status); &host_config_output_size, &host_config_output, status);
OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(), auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
tpu::kTpuMeshCommonStateResourceName, OP_REQUIRES_OK(ctx,
tpu::TpuMeshStateInterface::Create())); rmgr->Create(rmgr->default_container(),
tpu::kTpuMeshCommonStateResourceName, tpu_mesh));
Tensor* ctx_output; Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
@ -112,7 +114,8 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status); TF_DeleteStatus(status);
TpuConfigurationApi_FreeCharArray(host_config_output);
tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
VLOG(1) << "ConfigureDistributedTpuOp done"; VLOG(1) << "ConfigureDistributedTpuOp done";
} }
@ -171,7 +174,7 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state)); OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
core::ScopedUnref mesh_state_unref(mesh_state); core::ScopedUnref mesh_state_unref(mesh_state);
WaitForDistributedTpuOp_DoWork( tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
num_hosts, num_devices_per_host, num_hosts, num_devices_per_host,
const_cast<const int32_t**>(mapping_arg.data()), mesh_state, const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
&tpu_topology_output_size, &tpu_topology_output, status); &tpu_topology_output_size, &tpu_topology_output, status);
@ -183,7 +186,7 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status); TF_DeleteStatus(status);
TpuConfigurationApi_FreeCharArray(tpu_topology_output); tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
VLOG(1) << "WaitForDistributedTpuOp done"; VLOG(1) << "WaitForDistributedTpuOp done";
} }
@ -196,7 +199,7 @@ void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>( OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
GetTPUConfigResourceMgr(), GetTPUConfigResourceMgr(),
tpu::kTpuMeshCommonStateResourceName)); tpu::kTpuMeshCommonStateResourceName));
ShutdownDistributedTpuOp_DoWork(status); tpu::ConfigApiFn()->ShutdownDistributedTpuOp_DoWorkFn(status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status); TF_DeleteStatus(status);
@ -213,7 +216,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
int32_t* device_id_output; int32_t* device_id_output;
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
InitializeHostForDistributedTpuOp_DoWork( tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
tpu_host_config.size(), tpu_host_config.data(), tpu_host_config.size(), tpu_host_config.data(),
enable_whole_mesh_compilations_, &device_id_output_size, enable_whole_mesh_compilations_, &device_id_output_size,
&device_id_output, status); &device_id_output, status);
@ -230,7 +233,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status); TF_DeleteStatus(status);
TpuConfigurationApi_FreeInt32Array(device_id_output); tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
VLOG(1) << "InitializeHostForDistributedTpuOp done"; VLOG(1) << "InitializeHostForDistributedTpuOp done";
} }
@ -242,7 +245,8 @@ void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
auto tpu_topology = ctx->input(0).scalar<tstring>()(); auto tpu_topology = ctx->input(0).scalar<tstring>()();
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
SetGlobalTPUArrayOp_DoWork(tpu_topology.size(), tpu_topology.data(), status); tpu::ConfigApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
tpu_topology.data(), status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status); TF_DeleteStatus(status);
@ -257,7 +261,8 @@ void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
int32_t number_of_chips_output = 0; int32_t number_of_chips_output = 0;
DisconnectDistributedTpuChipsOp_DoWork(&number_of_chips_output, status); tpu::ConfigApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
&number_of_chips_output, status);
Tensor* ctx_output; Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));

View File

@ -15,20 +15,29 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_ #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_ #define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
#include "tensorflow/core/tpu/libtftpu.h"
typedef struct XLA_TpuMeshState XLA_TpuMeshState; typedef struct XLA_TpuMeshState XLA_TpuMeshState;
extern "C" { extern "C" {
// Creates a new TPU mesh state object. // Creates a new TPU mesh state object.
XLA_TpuMeshState* TpuMeshState_Create(); TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
// Deletes the given TPU `mesh_state` object. Once deleted the object is // Deletes the given TPU `mesh_state` object. Once deleted the object is
// unusable. // unusable.
void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
// Returns a pointer to an opaque mesh data structure used internally. // Returns a pointer to an opaque mesh data structure used internally.
void* TpuMeshState_MeshCommonState(XLA_TpuMeshState* mesh_state); TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
XLA_TpuMeshState* mesh_state);
} // extern "C" } // extern "C"
struct TfTpu_MeshStateApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
};
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_ #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/tpu_library_loader.h"
namespace tensorflow { namespace tensorflow {
@ -38,19 +39,19 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {
~TpuMeshStateInterface() override { ~TpuMeshStateInterface() override {
if (mesh_state_ != nullptr) { if (mesh_state_ != nullptr) {
TpuMeshState_Free(mesh_state_); MeshStateApiFn()->TpuMeshState_FreeFn(mesh_state_);
} }
} }
static TpuMeshStateInterface* Create() { static TpuMeshStateInterface* Create() {
return new TpuMeshStateInterface(TpuMeshState_Create()); return new TpuMeshStateInterface(MeshStateApiFn()->TpuMeshState_CreateFn());
} }
const XLA_TpuMeshState* data() const { return mesh_state_; } const XLA_TpuMeshState* data() const { return mesh_state_; }
tensorflow::TpuMeshCommonState* mesh_common_state() const { tensorflow::TpuMeshCommonState* mesh_common_state() const {
return static_cast<tensorflow::TpuMeshCommonState*>( return static_cast<tensorflow::TpuMeshCommonState*>(
TpuMeshState_MeshCommonState(mesh_state_)); MeshStateApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_));
} }
// Returns whether we should include the device assignment as a static field // Returns whether we should include the device assignment as a static field
@ -62,8 +63,8 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {
// Static device assignment enables XLA to perform certain optimization when // Static device assignment enables XLA to perform certain optimization when
// all cores are used in the replicated computation. // all cores are used in the replicated computation.
return metadata.num_cores_per_replica() * metadata.num_replicas() == return metadata.num_cores_per_replica() * metadata.num_replicas() ==
TpuTopology_AvailableCoreCount(mesh_state_, CompileApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_,
tpu_core_type); tpu_core_type);
} }
string DebugString() const override { return "TpuMeshStateInterface"; } string DebugString() const override { return "TpuMeshStateInterface"; }

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_ #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_ #define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h" #include "tensorflow/stream_executor/tpu/proto_helper.h"
typedef struct SE_Status SE_Status; typedef struct SE_Status SE_Status;
@ -32,4 +33,9 @@ void TpuCompile_ToTpuShapeRepresentation(
} // extern "C" } // extern "C"
struct TfTpu_UtilApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ToTpuShapeRepresentation);
};
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_ #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_

View File

@ -0,0 +1,166 @@
namespace {
tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
auto* config_fn = tensorflow::tpu::ConfigApiFn();
TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
return tensorflow::Status::OK();
}
tensorflow::Status SetTpuMeshStateStructFns(void* library_handle) {
auto* mesh_state_fn = tensorflow::tpu::MeshStateApiFn();
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Create);
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Free);
TFTPU_SET_FN(mesh_state_fn, TpuMeshState_MeshCommonState);
return tensorflow::Status::OK();
}
tensorflow::Status SetCompileStructFn(void* library_handle) {
auto* compile_fn = tensorflow::tpu::CompileApiFn();
TFTPU_SET_FN(compile_fn, TpuTopology_AvailableCoreCount);
TFTPU_SET_FN(compile_fn, TpuCompile_CreateCompilationCacheKey);
TFTPU_SET_FN(compile_fn, TpuCompile_CreateGuaranteedConstFingerprint);
TFTPU_SET_FN(compile_fn, TpuCompile_CompileAheadOfTime);
TFTPU_SET_FN(compile_fn, TpuCompile_BuildXLADeviceAssignment);
return tensorflow::Status::OK();
}
tensorflow::Status SetExecutorStructFn(void* library_handle) {
auto* executor_fn = tensorflow::tpu::ExecutorApiFn();
TFTPU_SET_FN(executor_fn, TpuPlatform_New);
TFTPU_SET_FN(executor_fn, TpuPlatform_Free);
TFTPU_SET_FN(executor_fn, TpuPlatform_Initialize);
TFTPU_SET_FN(executor_fn, TpuPlatform_Initialized);
TFTPU_SET_FN(executor_fn, TpuPlatform_GetExecutor);
TFTPU_SET_FN(executor_fn, TpuPlatform_Id);
TFTPU_SET_FN(executor_fn, TpuPlatform_VisibleDeviceCount);
TFTPU_SET_FN(executor_fn, TpuPlatform_TpuMemoryLimit);
TFTPU_SET_FN(executor_fn, TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
TFTPU_SET_FN(executor_fn, TpuExecutor_Init);
TFTPU_SET_FN(executor_fn, TpuExecutor_Free);
TFTPU_SET_FN(executor_fn, TpuExecutor_PlatformDeviceCount);
TFTPU_SET_FN(executor_fn, TpuExecutor_Allocate);
TFTPU_SET_FN(executor_fn, TpuExecutor_Deallocate);
TFTPU_SET_FN(executor_fn, TpuExecutor_GetAllocatorStats);
TFTPU_SET_FN(executor_fn, TpuExecutor_DeviceMemoryUsage);
TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateStream);
TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateStream);
TFTPU_SET_FN(executor_fn, TpuExecutor_CreateStreamDependency);
TFTPU_SET_FN(executor_fn, TpuExecutor_GetStatus);
TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateEvent);
TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateEvent);
TFTPU_SET_FN(executor_fn, TpuExecutor_PollForEventStatus);
TFTPU_SET_FN(executor_fn, TpuExecutor_RecordEvent);
TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForEvent);
TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateTimer);
TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateTimer);
TFTPU_SET_FN(executor_fn, TpuExecutor_StartTimer);
TFTPU_SET_FN(executor_fn, TpuExecutor_StopTimer);
TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronousMemcpyToHost);
TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronousMemcpyFromHost);
TFTPU_SET_FN(executor_fn, TpuExecutor_MemcpyToHost);
TFTPU_SET_FN(executor_fn, TpuExecutor_MemcpyFromHost);
TFTPU_SET_FN(executor_fn, TpuExecutor_EnqueueInfeed);
TFTPU_SET_FN(executor_fn, TpuExecutor_DequeueOutfeed);
TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForInfeedReady);
TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForOutfeedReady);
TFTPU_SET_FN(executor_fn, TpuExecutor_BlockHostUntilDone);
TFTPU_SET_FN(executor_fn, TpuExecutor_BlockUntilDoneOrFailed);
TFTPU_SET_FN(executor_fn, TpuExecutor_SyncAndForgetFailedStreams);
TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronizeAllActivity);
TFTPU_SET_FN(executor_fn, TpuStream_New);
TFTPU_SET_FN(executor_fn, TpuStream_Free);
TFTPU_SET_FN(executor_fn, TpuStream_Stream);
TFTPU_SET_FN(executor_fn, TpuStream_Status);
TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
TFTPU_SET_FN(executor_fn, TpuEvent_New);
TFTPU_SET_FN(executor_fn, TpuEvent_Free);
TFTPU_SET_FN(executor_fn, TpuTimer_New);
TFTPU_SET_FN(executor_fn, TpuTimer_Free);
TFTPU_SET_FN(executor_fn, TpuTimer_Nanoseconds);
TFTPU_SET_FN(executor_fn, TpuTimer_Microseconds);
TFTPU_SET_FN(executor_fn, TpuStatus_New);
TFTPU_SET_FN(executor_fn, TpuStatus_Create);
TFTPU_SET_FN(executor_fn, TpuStatus_Free);
TFTPU_SET_FN(executor_fn, TpuStatus_Message);
TFTPU_SET_FN(executor_fn, TpuStatus_Code);
TFTPU_SET_FN(executor_fn, TpuStatus_Ok);
TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Default);
TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_SetOrdinal);
TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Free);
TFTPU_SET_FN(executor_fn, TpuDeviceDescription_New);
TFTPU_SET_FN(executor_fn, TpuDeviceDescription_Free);
TFTPU_SET_FN(executor_fn, TpuExecutor_CreateDeviceDescription);
TFTPU_SET_FN(executor_fn, TpuExecutor_NewDeviceOptions);
TFTPU_SET_FN(executor_fn, TpuExecutor_FreeDeviceOptions);
TFTPU_SET_FN(executor_fn, TpuExecutor_HostCallback);
TFTPU_SET_FN(executor_fn, TpuTransferManager_New);
TFTPU_SET_FN(executor_fn, TpuTransferManager_Free);
TFTPU_SET_FN(executor_fn, TpuTransferManager_PlatformId);
TFTPU_SET_FN(executor_fn, TpuTransferManager_HostShapeToDeviceShape);
TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralToDeviceAsync);
TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralFromDevice);
TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_New);
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_Free);
return tensorflow::Status::OK();
}
tensorflow::Status SetTpuNodeContextStructFns(void* library_handle) {
auto* node_context_fn = tensorflow::tpu::NodeContextApiFn();
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Create);
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Free);
TFTPU_SET_FN(node_context_fn, TpuNodeContext_StopChipHeartbeats);
TFTPU_SET_FN(node_context_fn, TpuNodeContext_CloseTpuHost);
return tensorflow::Status::OK();
}
tensorflow::Status SetTpuUtilStructFns(void* library_handle) {
auto* util_fn = tensorflow::tpu::UtilApiFn();
TFTPU_SET_FN(util_fn, TpuCompile_IsTpuCompilationEnabled);
TFTPU_SET_FN(util_fn, TpuCompile_ToTpuShapeRepresentation);
return tensorflow::Status::OK();
}
tensorflow::Status InitializeTpuStructFns(void* library_handle) {
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle));
TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle));
TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle));
TF_RETURN_IF_ERROR(SetTpuUtilStructFns(library_handle));
return tensorflow::Status::OK();
}
} // namespace

View File

@ -13,16 +13,23 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// TODO(frankchn): Rename to `tpu_api_dlsym_initializer` or similar.
#include "tensorflow/core/tpu/tpu_library_loader.h" #include "tensorflow/core/tpu/tpu_library_loader.h"
#include <dlfcn.h> #include <dlfcn.h>
#define TFTPU_SET_FN(Struct, FnName) \
Struct->FnName##Fn = \
reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName));
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#define TFTPU_SET_FN(Struct, FnName) \
Struct->FnName##Fn = \
reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName)); \
if (!(Struct->FnName##Fn)) { \
LOG(ERROR) << #FnName " not available in this library."; \
}
// Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly // Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly
// visible methods. // visible methods.
@ -30,28 +37,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tpu { namespace tpu {
Status SetTpuInitializeStructFns(void* library_handle) { #include "tensorflow/core/tpu/tpu_library_init_fns.inc"
auto* base_fn = InitializeApiFn();
TFTPU_SET_FN(base_fn, TfTpu_Initialize);
return Status::OK();
}
Status SetTpuConfigStructFns(void* library_handle) {
auto* config_fn = ConfigApiFn();
TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
return Status::OK();
}
TfTpu_BaseFn* InitializeApiFn() { TfTpu_BaseFn* InitializeApiFn() {
static TfTpu_BaseFn base_fn; static TfTpu_BaseFn base_fn;
@ -63,19 +49,48 @@ TfTpu_ConfigApiFn* ConfigApiFn() {
return &config_api_fn; return &config_api_fn;
} }
TfTpu_MeshStateApiFn* MeshStateApiFn() {
static TfTpu_MeshStateApiFn mesh_state_api_fn;
return &mesh_state_api_fn;
}
TfTpu_CompileApiFn* CompileApiFn() {
static TfTpu_CompileApiFn compile_api_fn;
return &compile_api_fn;
}
TfTpu_ExecutorApiFn* ExecutorApiFn() {
static TfTpu_ExecutorApiFn executor_api_fn;
return &executor_api_fn;
}
TfTpu_NodeContextApiFn* NodeContextApiFn() {
static TfTpu_NodeContextApiFn node_context_api_fn;
return &node_context_api_fn;
}
TfTpu_UtilApiFn* UtilApiFn() {
static TfTpu_UtilApiFn util_api_fn;
return &util_api_fn;
}
Status InitializeTpuLibrary(void* library_handle) { Status InitializeTpuLibrary(void* library_handle) {
bool shared_object_loaded = true; bool shared_object_loaded = true;
if (library_handle == nullptr) { if (library_handle == nullptr) {
library_handle = dlopen(nullptr, RTLD_LAZY); library_handle = dlopen(nullptr, RTLD_NOW);
shared_object_loaded = false; shared_object_loaded = false;
} }
TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle)); TF_RETURN_IF_ERROR(InitializeTpuStructFns(library_handle));
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
if (shared_object_loaded) { if (shared_object_loaded) {
// TODO(frankchn): Make initialization actually work
// Initialize TPU platform when the platform code is loaded from a library. // Initialize TPU platform when the platform code is loaded from a library.
InitializeApiFn()->TfTpu_InitializeFn(); // InitializeApiFn()->TfTpu_InitializeFn();
// We should only register the TPU platform when the library is loaded.
// TODO(frankchn): Resolve the circular dependency and register the platform
// RegisterTpuPlatform();
} }
return Status::OK(); return Status::OK();

View File

@ -17,8 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_ #define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h" #include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h" #include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
// LINT.IfChange // LINT.IfChange
namespace tensorflow { namespace tensorflow {
@ -26,10 +31,21 @@ namespace tpu {
Status InitializeTpuLibrary(void* library_handle); Status InitializeTpuLibrary(void* library_handle);
// TODO(frankchn): Separate out API functions from the loader.
TfTpu_BaseFn* InitializeApiFn(); TfTpu_BaseFn* InitializeApiFn();
TfTpu_ConfigApiFn* ConfigApiFn(); TfTpu_ConfigApiFn* ConfigApiFn();
TfTpu_MeshStateApiFn* MeshStateApiFn();
TfTpu_CompileApiFn* CompileApiFn();
TfTpu_ExecutorApiFn* ExecutorApiFn();
TfTpu_NodeContextApiFn* NodeContextApiFn();
TfTpu_UtilApiFn* UtilApiFn();
} // namespace tpu } // namespace tpu
} // namespace tensorflow } // namespace tensorflow
// LINT.ThenChange(//tensorflow/core/tpu/tpu_library_loader_windows.cc) // LINT.ThenChange(//tensorflow/core/tpu/tpu_library_loader_windows.cc)

View File

@ -27,6 +27,16 @@ TfTpu_BaseFn* InitializeApiFn() { return nullptr; }
TfTpu_ConfigApiFn* ConfigApiFn() { return nullptr; } TfTpu_ConfigApiFn* ConfigApiFn() { return nullptr; }
TfTpu_MeshStateApiFn* MeshStateApiFn() { return nullptr; }
TfTpu_CompileApiFn* CompileApiFn() { return nullptr; }
TfTpu_ExecutorApiFn* ExecutorApiFn() { return nullptr; }
TfTpu_NodeContextApiFn* NodeContextApiFn() { return nullptr; }
TfTpu_UtilApiFn* UtilApiFn() { return nullptr; }
Status InitializeTpuLibrary(void* library_handle) { Status InitializeTpuLibrary(void* library_handle) {
return errors::Unimplemented( return errors::Unimplemented(
"Loading TPU library is not supported on Windows."); "Loading TPU library is not supported on Windows.");

View File

@ -11,20 +11,25 @@ package(
cc_library( cc_library(
name = "tpu_executor_c_api_hdrs", name = "tpu_executor_c_api_hdrs",
hdrs = ["tpu_executor_c_api.h"], hdrs = ["tpu_executor_c_api.h"],
visibility = ["//visibility:public"],
deps = [ deps = [
"//tensorflow/c:tf_attrtype", "//tensorflow/c:tf_attrtype",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/core/tpu/kernels:tpu_ops_common_c_api_hdrs", "//tensorflow/core/tpu/kernels:tpu_ops_common_c_api_hdrs",
], ],
alwayslink = True,
) )
cc_library( cc_library(
name = "tpu_node_context_c_api_hdrs", name = "tpu_node_context_c_api_hdrs",
hdrs = ["tpu_node_context_c_api.h"], hdrs = ["tpu_node_context_c_api.h"],
visibility = ["//visibility:public"],
deps = [ deps = [
":tpu_executor_c_api_hdrs", ":tpu_executor_c_api_hdrs",
"//tensorflow/core/tpu:libtftpu_header",
], ],
alwayslink = True,
) )
cc_library( cc_library(
@ -65,6 +70,7 @@ cc_library(
":status_helper", ":status_helper",
":tpu_executor_c_api_hdrs", ":tpu_executor_c_api_hdrs",
":tpu_stream_interface", ":tpu_stream_interface",
"//tensorflow/core/tpu:tpu_library_loader",
"//tensorflow/stream_executor:stream", "//tensorflow/stream_executor:stream",
], ],
) )
@ -75,6 +81,7 @@ cc_library(
deps = [ deps = [
":tpu_executor_c_api_hdrs", ":tpu_executor_c_api_hdrs",
"//tensorflow/core/platform:types", "//tensorflow/core/platform:types",
"//tensorflow/core/tpu:tpu_library_loader",
"//tensorflow/stream_executor:stream", "//tensorflow/stream_executor:stream",
], ],
) )
@ -94,6 +101,7 @@ cc_library(
":tpu_timer", ":tpu_timer",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/tpu:tpu_library_loader",
"//tensorflow/stream_executor:stream", "//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
@ -143,6 +151,7 @@ cc_library(
"//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core/tpu:tpu_library_loader",
"//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
@ -160,6 +169,7 @@ cc_library(
":tpu_platform_interface", ":tpu_platform_interface",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/core/platform:types", "//tensorflow/core/platform:types",
"//tensorflow/core/tpu:tpu_library_loader",
"//tensorflow/stream_executor:stream", "//tensorflow/stream_executor:stream",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
], ],
@ -191,6 +201,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core/tpu:tpu_library_loader",
"//tensorflow/stream_executor:stream", "//tensorflow/stream_executor:stream",
], ],
) )
@ -217,6 +228,7 @@ cc_library(
"//tensorflow/core/platform:types", "//tensorflow/core/platform:types",
"//tensorflow/stream_executor:multi_platform_manager", "//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor:stream_executor_headers",
"//tensorflow/stream_executor:stream_executor_pimpl",
], ],
) )

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/tpu/tpu_library_loader.h"
#include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h" #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
@ -33,63 +34,68 @@ namespace {
using ::stream_executor::port::Status; using ::stream_executor::port::Status;
} // namespace } // namespace
TpuExecutor::~TpuExecutor() { TpuExecutor_Free(executor_); } TpuExecutor::~TpuExecutor() {
tpu::ExecutorApiFn()->TpuExecutor_FreeFn(executor_);
}
Status TpuExecutor::Init(int device_ordinal, Status TpuExecutor::Init(int device_ordinal,
::stream_executor::DeviceOptions device_options) { ::stream_executor::DeviceOptions device_options) {
StatusHelper status; StatusHelper status;
SE_DeviceOptions* options = SE_DeviceOptions* options =
TpuExecutor_NewDeviceOptions(device_options.flags()); tpu::ExecutorApiFn()->TpuExecutor_NewDeviceOptionsFn(
TpuExecutor_Init(executor_, device_ordinal, options, status.c_status); device_options.flags());
TpuExecutor_FreeDeviceOptions(options); tpu::ExecutorApiFn()->TpuExecutor_InitFn(executor_, device_ordinal, options,
status.c_status);
tpu::ExecutorApiFn()->TpuExecutor_FreeDeviceOptionsFn(options);
return status.status(); return status.status();
} }
int TpuExecutor::PlatformDeviceCount() { int TpuExecutor::PlatformDeviceCount() {
return TpuExecutor_PlatformDeviceCount(executor_); return tpu::ExecutorApiFn()->TpuExecutor_PlatformDeviceCountFn(executor_);
} }
void TpuExecutor::SyncAndForgetFailedStreams() { void TpuExecutor::SyncAndForgetFailedStreams() {
TpuExecutor_SyncAndForgetFailedStreams(executor_); tpu::ExecutorApiFn()->TpuExecutor_SyncAndForgetFailedStreamsFn(executor_);
} }
bool TpuExecutor::SynchronizeAllActivity() { bool TpuExecutor::SynchronizeAllActivity() {
return TpuExecutor_SynchronizeAllActivity(executor_); return tpu::ExecutorApiFn()->TpuExecutor_SynchronizeAllActivityFn(executor_);
} }
Status TpuExecutor::BlockHostUntilDone(Stream* stream) { Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
StatusHelper status; StatusHelper status;
TpuExecutor_BlockHostUntilDone( tpu::ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn(
executor_, stream_map().at(stream->implementation()), status.c_status); executor_, stream_map().at(stream->implementation()), status.c_status);
return status.status(); return status.status();
} }
Status TpuExecutor::BlockUntilDoneOrFailed() { Status TpuExecutor::BlockUntilDoneOrFailed() {
StatusHelper status; StatusHelper status;
TpuExecutor_BlockUntilDoneOrFailed(executor_, status.c_status); tpu::ExecutorApiFn()->TpuExecutor_BlockUntilDoneOrFailedFn(executor_,
status.c_status);
return status.status(); return status.status();
} }
Status TpuExecutor::GetStatus(Stream* stream) { Status TpuExecutor::GetStatus(Stream* stream) {
StatusHelper status; StatusHelper status;
TpuExecutor_GetStatus(executor_, stream_map().at(stream->implementation()), tpu::ExecutorApiFn()->TpuExecutor_GetStatusFn(
status.c_status); executor_, stream_map().at(stream->implementation()), status.c_status);
return status.status(); return status.status();
} }
bool TpuExecutor::AllocateStream(Stream* stream) { bool TpuExecutor::AllocateStream(Stream* stream) {
return TpuExecutor_AllocateStream(executor_, return tpu::ExecutorApiFn()->TpuExecutor_AllocateStreamFn(
stream_map().at(stream->implementation())); executor_, stream_map().at(stream->implementation()));
} }
void TpuExecutor::DeallocateStream(Stream* stream) { void TpuExecutor::DeallocateStream(Stream* stream) {
TpuExecutor_DeallocateStream(executor_, tpu::ExecutorApiFn()->TpuExecutor_DeallocateStreamFn(
stream_map().at(stream->implementation())); executor_, stream_map().at(stream->implementation()));
stream_map().erase(stream->implementation()); stream_map().erase(stream->implementation());
} }
bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) { bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
return TpuExecutor_CreateStreamDependency( return tpu::ExecutorApiFn()->TpuExecutor_CreateStreamDependencyFn(
executor_, stream_map().at(dependent->implementation()), executor_, stream_map().at(dependent->implementation()),
stream_map().at(other->implementation())); stream_map().at(other->implementation()));
} }
@ -104,15 +110,15 @@ bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
void TpuExecutor::DeallocateTimer(Timer* timer) {} void TpuExecutor::DeallocateTimer(Timer* timer) {}
bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) { bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
return TpuExecutor_StartTimer(executor_, return tpu::ExecutorApiFn()->TpuExecutor_StartTimerFn(
stream_map().at(stream->implementation()), executor_, stream_map().at(stream->implementation()),
timer_map_.at(timer->implementation())); timer_map_.at(timer->implementation()));
} }
bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) { bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
return TpuExecutor_StopTimer(executor_, return tpu::ExecutorApiFn()->TpuExecutor_StopTimerFn(
stream_map().at(stream->implementation()), executor_, stream_map().at(stream->implementation()),
timer_map_.at(timer->implementation())); timer_map_.at(timer->implementation()));
} }
stream_executor::Event::Status TpuExecutor::PollForEventStatus( stream_executor::Event::Status TpuExecutor::PollForEventStatus(
@ -148,7 +154,7 @@ Status TpuExecutor::WaitForEvent(Stream* stream,
// Called by Timer::Timer // Called by Timer::Timer
std::unique_ptr<::stream_executor::internal::TimerInterface> std::unique_ptr<::stream_executor::internal::TimerInterface>
TpuExecutor::GetTimerImplementation() { TpuExecutor::GetTimerImplementation() {
SE_Timer* tpu_timer = TpuTimer_New(executor_); SE_Timer* tpu_timer = tpu::ExecutorApiFn()->TpuTimer_NewFn(executor_);
auto ptr = absl::make_unique<TpuTimer>(tpu_timer); auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
timer_map_[ptr.get()] = tpu_timer; timer_map_[ptr.get()] = tpu_timer;
return ptr; return ptr;
@ -157,7 +163,7 @@ TpuExecutor::GetTimerImplementation() {
// Called by Stream::Stream // Called by Stream::Stream
std::unique_ptr<::stream_executor::internal::StreamInterface> std::unique_ptr<::stream_executor::internal::StreamInterface>
TpuExecutor::GetStreamImplementation() { TpuExecutor::GetStreamImplementation() {
SE_Stream* tpu_stream = TpuStream_New(executor_); SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
auto ptr = absl::make_unique<TpuStream>(tpu_stream); auto ptr = absl::make_unique<TpuStream>(tpu_stream);
stream_map()[ptr.get()] = tpu_stream; stream_map()[ptr.get()] = tpu_stream;
return ptr; return ptr;
@ -166,34 +172,35 @@ TpuExecutor::GetStreamImplementation() {
// Called by Event::Event // Called by Event::Event
std::unique_ptr<::stream_executor::internal::EventInterface> std::unique_ptr<::stream_executor::internal::EventInterface>
TpuExecutor::CreateEventImplementation() { TpuExecutor::CreateEventImplementation() {
SE_Event* tpu_event = TpuEvent_New(executor_); SE_Event* tpu_event = tpu::ExecutorApiFn()->TpuEvent_NewFn(executor_);
auto ptr = absl::make_unique<TpuEvent>(tpu_event); auto ptr = absl::make_unique<TpuEvent>(tpu_event);
event_map()[ptr.get()] = tpu_event; event_map()[ptr.get()] = tpu_event;
return ptr; return ptr;
} }
DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) { DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
SE_DeviceMemoryBase se_base = SE_DeviceMemoryBase se_base = tpu::ExecutorApiFn()->TpuExecutor_AllocateFn(
TpuExecutor_Allocate(executor_, size, memory_space); executor_, size, memory_space);
return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base); return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
} }
void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) { void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
SE_DeviceMemoryBase se_base = SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory); TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
TpuExecutor_Deallocate(executor_, &se_base); tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
} }
void TpuExecutor::Deallocate(DeviceMemoryBase* memory) { void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
SE_DeviceMemoryBase se_base = SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory); TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
TpuExecutor_Deallocate(executor_, &se_base); tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
} }
bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const { bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
int64_t _free; int64_t _free;
int64_t _total; int64_t _total;
if (TpuExecutor_DeviceMemoryUsage(executor_, &_free, &_total)) { if (tpu::ExecutorApiFn()->TpuExecutor_DeviceMemoryUsageFn(executor_, &_free,
&_total)) {
*free = _free; *free = _free;
*total = _total; *total = _total;
return true; return true;
@ -204,7 +211,8 @@ bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
absl::optional<stream_executor::AllocatorStats> absl::optional<stream_executor::AllocatorStats>
TpuExecutor::GetAllocatorStats() { TpuExecutor::GetAllocatorStats() {
SE_AllocatorStats c_stats; SE_AllocatorStats c_stats;
if (TpuExecutor_GetAllocatorStats(executor_, &c_stats)) { if (tpu::ExecutorApiFn()->TpuExecutor_GetAllocatorStatsFn(executor_,
&c_stats)) {
::stream_executor::AllocatorStats stats; ::stream_executor::AllocatorStats stats;
stats.num_allocs = c_stats.num_allocs; stats.num_allocs = c_stats.num_allocs;
stats.bytes_in_use = c_stats.bytes_in_use; stats.bytes_in_use = c_stats.bytes_in_use;
@ -226,31 +234,33 @@ TpuExecutor::GetAllocatorStats() {
Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) { Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
StatusHelper status; StatusHelper status;
TpuExecutor_WaitForInfeedReady(executor_, infeed_queue_index, tpu::ExecutorApiFn()->TpuExecutor_WaitForInfeedReadyFn(
status.c_status); executor_, infeed_queue_index, status.c_status);
return status.status(); return status.status();
} }
Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) { Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
StatusHelper status; StatusHelper status;
TpuExecutor_WaitForOutfeedReady(executor_, outfeed_queue_index, tpu::ExecutorApiFn()->TpuExecutor_WaitForOutfeedReadyFn(
status.c_status); executor_, outfeed_queue_index, status.c_status);
return status.status(); return status.status();
} }
void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index, void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
absl::Span<uint8> bytes, StatusCallback done) { absl::Span<uint8> bytes, StatusCallback done) {
StatusHelper status; StatusHelper status;
TpuExecutor_DequeueOutfeed(executor_, outfeed_queue_index, bytes.data(), tpu::ExecutorApiFn()->TpuExecutor_DequeueOutfeedFn(
bytes.size(), status.c_status); executor_, outfeed_queue_index, bytes.data(), bytes.size(),
status.c_status);
done(status.status()); done(status.status());
} }
Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index, Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
absl::Span<const uint8> bytes) { absl::Span<const uint8> bytes) {
StatusHelper status; StatusHelper status;
TpuExecutor_EnqueueInfeed(executor_, infeed_queue_index, bytes.data(), tpu::ExecutorApiFn()->TpuExecutor_EnqueueInfeedFn(
bytes.size(), status.c_status); executor_, infeed_queue_index, bytes.data(), bytes.size(),
status.c_status);
return status.status(); return status.status();
} }
@ -259,9 +269,9 @@ bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
uint64 size) { uint64 size) {
SE_DeviceMemoryBase se_base = SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src); TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
return TpuExecutor_MemcpyToHost(executor_, return tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
stream_map().at(stream->implementation()), executor_, stream_map().at(stream->implementation()), host_dst, &se_base,
host_dst, &se_base, size); size);
} }
bool TpuExecutor::Memcpy(Stream* stream, bool TpuExecutor::Memcpy(Stream* stream,
@ -269,9 +279,9 @@ bool TpuExecutor::Memcpy(Stream* stream,
const void* host_src, uint64 size) { const void* host_src, uint64 size) {
SE_DeviceMemoryBase se_base = SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst); TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
return TpuExecutor_MemcpyFromHost(executor_, return tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
stream_map().at(stream->implementation()), executor_, stream_map().at(stream->implementation()), &se_base, host_src,
&se_base, host_src, size); size);
} }
Status TpuExecutor::SynchronousMemcpy( Status TpuExecutor::SynchronousMemcpy(
@ -280,8 +290,8 @@ Status TpuExecutor::SynchronousMemcpy(
StatusHelper status; StatusHelper status;
SE_DeviceMemoryBase se_base = SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst); TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
TpuExecutor_SynchronousMemcpyFromHost(executor_, &se_base, host_src, size, tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyFromHostFn(
status.c_status); executor_, &se_base, host_src, size, status.c_status);
return status.status(); return status.status();
} }
@ -291,8 +301,8 @@ Status TpuExecutor::SynchronousMemcpy(
StatusHelper status; StatusHelper status;
SE_DeviceMemoryBase se_base = SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src); TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
TpuExecutor_SynchronousMemcpyToHost(executor_, host_dst, &se_base, size, tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyToHostFn(
status.c_status); executor_, host_dst, &se_base, size, status.c_status);
return status.status(); return status.status();
} }
@ -316,8 +326,8 @@ struct HostCallbackContext {
SE_Status* HostCallbackTrampoline(void* ctx) { SE_Status* HostCallbackTrampoline(void* ctx) {
HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx); HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
Status status = host_ctx->callback(); Status status = host_ctx->callback();
SE_Status* c_status = SE_Status* c_status = tpu::ExecutorApiFn()->TpuStatus_CreateFn(
TpuStatus_Create(status.code(), status.error_message().c_str()); status.code(), status.error_message().c_str());
delete host_ctx; delete host_ctx;
return c_status; return c_status;
} }
@ -325,18 +335,21 @@ SE_Status* HostCallbackTrampoline(void* ctx) {
bool TpuExecutor::HostCallback(Stream* stream, bool TpuExecutor::HostCallback(Stream* stream,
std::function<Status()> callback) { std::function<Status()> callback) {
HostCallbackContext* ctx = new HostCallbackContext{callback}; HostCallbackContext* ctx = new HostCallbackContext{callback};
return TpuExecutor_HostCallback(executor_, return tpu::ExecutorApiFn()->TpuExecutor_HostCallbackFn(
stream_map().at(stream->implementation()), executor_, stream_map().at(stream->implementation()),
&HostCallbackTrampoline, ctx); &HostCallbackTrampoline, ctx);
} }
TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>> TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
TpuExecutor::CreateDeviceDescription() const { TpuExecutor::CreateDeviceDescription() const {
StatusHelper status; StatusHelper status;
SE_DeviceDescription* description = TpuDeviceDescription_New(); SE_DeviceDescription* description =
auto cleanup = tensorflow::gtl::MakeCleanup( tpu::ExecutorApiFn()->TpuDeviceDescription_NewFn();
[description]() { TpuDeviceDescription_Free(description); }); auto cleanup = tensorflow::gtl::MakeCleanup([description]() {
TpuExecutor_CreateDeviceDescription(executor_, description, status.c_status); tpu::ExecutorApiFn()->TpuDeviceDescription_FreeFn(description);
});
tpu::ExecutorApiFn()->TpuExecutor_CreateDeviceDescriptionFn(
executor_, description, status.c_status);
if (status.status().ok()) { if (status.status().ok()) {
stream_executor::internal::DeviceDescriptionBuilder builder; stream_executor::internal::DeviceDescriptionBuilder builder;
CHECK_NE(description->device_vendor, nullptr); CHECK_NE(description->device_vendor, nullptr);

View File

@ -20,9 +20,9 @@ limitations under the License.
#include <stdint.h> #include <stdint.h>
#include "tensorflow/c/tf_attrtype.h" #include "tensorflow/c/tf_attrtype.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
typedef struct SE_Platform SE_Platform; typedef struct SE_Platform SE_Platform;
typedef struct SE_StreamExecutor SE_StreamExecutor; typedef struct SE_StreamExecutor SE_StreamExecutor;
@ -292,6 +292,96 @@ void TpuTransferManager_WriteSingleTupleIndexTable(
XLA_ComputationPlacer* TpuComputationPlacer_New(); XLA_ComputationPlacer* TpuComputationPlacer_New();
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer); void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_New);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Initialize);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Initialized);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetExecutor);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Id);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_VisibleDeviceCount);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_TpuMemoryLimit);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Init);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_PlatformDeviceCount);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Allocate);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Deallocate);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_GetAllocatorStats);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeviceMemoryUsage);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateStream);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateStream);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_CreateStreamDependency);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_GetStatus);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateEvent);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateEvent);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_PollForEventStatus);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_RecordEvent);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForEvent);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateTimer);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateTimer);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_StartTimer);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_StopTimer);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronousMemcpyToHost);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronousMemcpyFromHost);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_MemcpyToHost);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_MemcpyFromHost);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_EnqueueInfeed);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DequeueOutfeed);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForInfeedReady);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForOutfeedReady);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_BlockHostUntilDone);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_BlockUntilDoneOrFailed);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SyncAndForgetFailedStreams);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronizeAllActivity);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_New);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);
TFTPU_ADD_FN_IN_STRUCT(TpuEvent_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuTimer_New);
TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Nanoseconds);
TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Microseconds);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_New);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Create);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Message);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Code);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Ok);
TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Default);
TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_SetOrdinal);
TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_New);
TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_CreateDeviceDescription);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_NewDeviceOptions);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_FreeDeviceOptions);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_HostCallback);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_New);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_PlatformId);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_HostShapeToDeviceShape);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralToDeviceAsync);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralFromDevice);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_New);
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_Free);
};
} }
// extern "C" // extern "C"

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/core/tpu/tpu_library_loader.h"
#include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
@ -32,15 +33,18 @@ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
int device_ordinal) { int device_ordinal) {
StatusHelper status; StatusHelper status;
XLA_TpuNodeContext* node_context = XLA_TpuNodeContext* node_context =
TpuNodeContext_Create(device_ordinal, status.c_status); tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(device_ordinal,
status.c_status);
if (!status.status().ok()) { if (!status.status().ok()) {
TpuNodeContext_Free(node_context); tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context);
return status.status(); return status.status();
} }
return std::make_unique<TpuNodeContext>(device_ordinal, node_context); return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
} }
TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); } TpuNodeContext::~TpuNodeContext() {
tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context_);
}
/* static */ /* static */
Status TpuNodeContext::Initialize(int device_ordinal) { Status TpuNodeContext::Initialize(int device_ordinal) {
@ -52,14 +56,14 @@ Status TpuNodeContext::Initialize(int device_ordinal) {
/* static */ /* static */
Status TpuNodeContext::StopChipHeartbeats() { Status TpuNodeContext::StopChipHeartbeats() {
StatusHelper status; StatusHelper status;
TpuNodeContext_StopChipHeartbeats(status.c_status); tpu::NodeContextApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status);
return status.status(); return status.status();
} }
/* static */ /* static */
Status TpuNodeContext::CloseTpuHost() { Status TpuNodeContext::CloseTpuHost() {
StatusHelper status; StatusHelper status;
TpuNodeContext_CloseTpuHost(status.c_status); tpu::NodeContextApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status);
return status.status(); return status.status();
} }

View File

@ -15,10 +15,13 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_ #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_ #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
typedef struct XLA_TpuNodeContext XLA_TpuNodeContext; typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
extern "C" {
XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal, XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
SE_Status* status); SE_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context); void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
@ -28,4 +31,13 @@ void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
void TpuNodeContext_StopChipHeartbeats(SE_Status* status); void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
void TpuNodeContext_CloseTpuHost(SE_Status* status); void TpuNodeContext_CloseTpuHost(SE_Status* status);
} // extern "C"
struct TfTpu_NodeContextApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
};
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_ #endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_platform.h" #include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/tpu/tpu_library_loader.h"
#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/tpu/status_helper.h" #include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h" #include "tensorflow/stream_executor/tpu/tpu_executor.h"
@ -30,7 +31,9 @@ using Status = ::stream_executor::port::Status;
template <typename T> template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>; using StatusOr = ::stream_executor::port::StatusOr<T>;
TpuPlatform::TpuPlatform() { platform_ = TpuPlatform_New(); } TpuPlatform::TpuPlatform() {
platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
}
TpuPlatform* TpuPlatform::GetRegisteredPlatform() { TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
return tpu_registered_platform; return tpu_registered_platform;
@ -53,8 +56,8 @@ Status TpuPlatform::Initialize(
i++; i++;
} }
TpuPlatform_Initialize(platform_, options_size, options_key, options_value, tpu::ExecutorApiFn()->TpuPlatform_InitializeFn(
status.c_status); platform_, options_size, options_key, options_value, status.c_status);
free(options_key); free(options_key);
free(options_value); free(options_value);
@ -62,10 +65,16 @@ Status TpuPlatform::Initialize(
return status.status(); return status.status();
} }
TpuPlatform::~TpuPlatform() { TpuPlatform_Free(platform_); } bool TpuPlatform::Initialized() const {
return tpu::ExecutorApiFn()->TpuPlatform_InitializedFn(platform_);
}
TpuPlatform::~TpuPlatform() {
tpu::ExecutorApiFn()->TpuPlatform_FreeFn(platform_);
}
int TpuPlatform::VisibleDeviceCount() const { int TpuPlatform::VisibleDeviceCount() const {
return TpuPlatform_VisibleDeviceCount(platform_); return tpu::ExecutorApiFn()->TpuPlatform_VisibleDeviceCountFn(platform_);
} }
StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor( StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
@ -77,14 +86,16 @@ StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>> StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
TpuPlatform::GetUncachedExecutor( TpuPlatform::GetUncachedExecutor(
const ::stream_executor::StreamExecutorConfig& config) { const ::stream_executor::StreamExecutorConfig& config) {
SE_StreamExecutorConfig* c_config = TpuStreamExecutorConfig_Default(); SE_StreamExecutorConfig* c_config =
tpu::ExecutorApiFn()->TpuStreamExecutorConfig_DefaultFn();
TpuStreamExecutorConfig_SetOrdinal(c_config, config.ordinal); tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn(c_config,
config.ordinal);
StatusHelper status; StatusHelper status;
SE_StreamExecutor* executor = SE_StreamExecutor* executor = tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn(
TpuPlatform_GetExecutor(platform_, c_config, status.c_status); platform_, c_config, status.c_status);
TpuStreamExecutorConfig_Free(c_config); tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn(c_config);
if (!status.ok()) { if (!status.ok()) {
return status.status(); return status.status();
} }
@ -103,27 +114,24 @@ const std::string& TpuPlatform::Name() const {
} }
int64 TpuPlatform::TpuMemoryLimit() { int64 TpuPlatform::TpuMemoryLimit() {
return TpuPlatform_TpuMemoryLimit(platform_); return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
} }
bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() { bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
return TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(platform_); return tpu::ExecutorApiFn()
->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
}
void RegisterTpuPlatform() {
static bool tpu_platform_registered = false;
if (!tpu_platform_registered) {
tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
std::unique_ptr<stream_executor::Platform> platform(
tensorflow::tpu_registered_platform);
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(platform)));
tpu_platform_registered = true;
}
} }
} // namespace tensorflow } // namespace tensorflow
void RegisterTpuPlatform() {
tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
std::unique_ptr<stream_executor::Platform> platform(
tensorflow::tpu_registered_platform);
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(platform)));
}
REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());
// Note that module initialization sequencing is not supported in the
// open-source project, so this will be a no-op there.
REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
tpu_platform);

View File

@ -60,9 +60,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
bool ShouldRegisterTpuDeviceToDeviceCopy() override; bool ShouldRegisterTpuDeviceToDeviceCopy() override;
bool Initialized() const override { bool Initialized() const override;
return TpuPlatform_Initialized(platform_);
}
Status Initialize( Status Initialize(
const std::map<std::string, std::string>& platform_options) override; const std::map<std::string, std::string>& platform_options) override;
@ -124,6 +122,8 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
EventMap event_map_; EventMap event_map_;
}; };
void RegisterTpuPlatform();
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_ #endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_ #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_ #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
#include "tensorflow/core/tpu/tpu_library_loader.h"
#include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h" #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h" #include "tensorflow/stream_executor/tpu/status_helper.h"
@ -27,23 +28,27 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
using Status = stream_executor::port::Status; using Status = stream_executor::port::Status;
explicit TpuStream(SE_Stream* stream) : stream_(stream) {} explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
~TpuStream() override { TpuStream_Free(stream_); } ~TpuStream() override {
tensorflow::tpu::ExecutorApiFn()->TpuStream_FreeFn(stream_);
}
bool IsSameSharedMemoryLocation( bool IsSameSharedMemoryLocation(
tensorflow::tpu::TpuStreamInterface* other) override { tensorflow::tpu::TpuStreamInterface* other) override {
return TpuStream_IsSameSharedMemoryLocation( return tensorflow::tpu::ExecutorApiFn()
stream_, static_cast<TpuStream*>(other)->stream_); ->TpuStream_IsSameSharedMemoryLocationFn(
stream_, static_cast<TpuStream*>(other)->stream_);
} }
Status EnqueueOnTpuDeviceSendRecvLocal( Status EnqueueOnTpuDeviceSendRecvLocal(
stream_executor::DeviceMemoryBase send_buffer, stream_executor::DeviceMemoryBase send_buffer,
stream_executor::DeviceMemoryBase recv_buffer) override { stream_executor::DeviceMemoryBase recv_buffer) override {
StatusHelper status; StatusHelper status;
TpuStream_TpuEnqueueOnDeviceSendRecvLocal( tensorflow::tpu::ExecutorApiFn()
stream_, ->TpuStream_TpuEnqueueOnDeviceSendRecvLocalFn(
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer), stream_,
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer), TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer),
status.c_status); TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer),
status.c_status);
return status.status(); return status.status();
} }
@ -54,7 +59,9 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
class TpuEvent : public ::stream_executor::internal::EventInterface { class TpuEvent : public ::stream_executor::internal::EventInterface {
public: public:
explicit TpuEvent(SE_Event* event) : event_(event) {} explicit TpuEvent(SE_Event* event) : event_(event) {}
~TpuEvent() override { TpuEvent_Free(event_); } ~TpuEvent() override {
tensorflow::tpu::ExecutorApiFn()->TpuEvent_FreeFn(event_);
}
private: private:
SE_Event* event_; SE_Event* event_;

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_ #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tpu/tpu_library_loader.h"
#include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
@ -25,9 +26,15 @@ namespace tensorflow {
class TpuTimer : public ::stream_executor::internal::TimerInterface { class TpuTimer : public ::stream_executor::internal::TimerInterface {
public: public:
explicit TpuTimer(SE_Timer* timer) : timer_(timer) {} explicit TpuTimer(SE_Timer* timer) : timer_(timer) {}
~TpuTimer() override { TpuTimer_Free(timer_); } ~TpuTimer() override {
uint64 Microseconds() const override { return TpuTimer_Microseconds(timer_); } tensorflow::tpu::ExecutorApiFn()->TpuTimer_FreeFn(timer_);
uint64 Nanoseconds() const override { return TpuTimer_Nanoseconds(timer_); } }
uint64 Microseconds() const override {
return tensorflow::tpu::ExecutorApiFn()->TpuTimer_MicrosecondsFn(timer_);
}
uint64 Nanoseconds() const override {
return tensorflow::tpu::ExecutorApiFn()->TpuTimer_NanosecondsFn(timer_);
}
private: private:
SE_Timer* timer_; SE_Timer* timer_;

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/tpu/tpu_library_loader.h"
#include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h" #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h" #include "tensorflow/stream_executor/tpu/proto_helper.h"
@ -29,10 +30,12 @@ namespace tensorflow {
using Status = stream_executor::port::Status; using Status = stream_executor::port::Status;
TpuTransferManager::TpuTransferManager() { TpuTransferManager::TpuTransferManager() {
manager_ = TpuTransferManager_New(); manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
} }
TpuTransferManager::~TpuTransferManager() { TpuTransferManager_Free(manager_); } TpuTransferManager::~TpuTransferManager() {
tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
}
stream_executor::Platform::Id TpuTransferManager::PlatformId() const { stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
return TpuPlatform::kId; return TpuPlatform::kId;
@ -45,8 +48,8 @@ xla::Shape TpuTransferManager::HostShapeToDeviceShape(
TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape); TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
TpuTransferManager_HostShapeToDeviceShape(manager_, &c_host_shape, tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
&c_device_shape); manager_, &c_host_shape, &c_device_shape);
xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape); xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
TpuConversions::CShapeCleanup(&c_host_shape); TpuConversions::CShapeCleanup(&c_host_shape);
TpuConversions::CShapeCleanup(&c_device_shape); TpuConversions::CShapeCleanup(&c_device_shape);
@ -66,7 +69,7 @@ Status TpuTransferManager::TransferLiteralToDeviceAsync(
TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer, TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
&c_device_buffer); &c_device_buffer);
TpuTransferManager_TransferLiteralToDeviceAsync( tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
manager_, manager_,
TpuPlatform::GetRegisteredPlatform()->stream_map()->at( TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
stream->implementation()), stream->implementation()),
@ -112,7 +115,7 @@ void TpuTransferManager::TransferLiteralFromDevice(
XLA_Literal c_literal; XLA_Literal c_literal;
TpuConversions::XLALiteralToCLiteral(literal, &c_literal); TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
TpuTransferManager_TransferLiteralFromDevice( tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
manager_, manager_,
TpuPlatform::GetRegisteredPlatform()->stream_map()->at( TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
stream->implementation()), stream->implementation()),
@ -127,7 +130,8 @@ int64 TpuTransferManager::GetByteSizeRequirement(
TpuConversions::XlaShapeToCShape(shape, &c_shape); TpuConversions::XlaShapeToCShape(shape, &c_shape);
int64 size_in_bytes = int64 size_in_bytes =
TpuTransferManager_GetByteSizeRequirement(manager_, &c_shape); tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
manager_, &c_shape);
TpuConversions::CShapeCleanup(&c_shape); TpuConversions::CShapeCleanup(&c_shape);
return size_in_bytes; return size_in_bytes;
@ -151,7 +155,7 @@ Status TpuTransferManager::WriteSingleTupleIndexTable(
region->payload()}; region->payload()};
StatusHelper status; StatusHelper status;
TpuTransferManager_WriteSingleTupleIndexTable( tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
manager_, manager_,
TpuPlatform::GetRegisteredPlatform()->stream_map()->at( TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
stream->implementation()), stream->implementation()),

View File

@ -2899,6 +2899,13 @@ def if_mlir(if_true, if_false = []):
"//conditions:default": if_false, "//conditions:default": if_false,
}) })
def if_tpu(if_true, if_false = []):
"""Shorthand for select()ing whether to build for TPUs."""
return select({
str(Label("//tensorflow:with_tpu_support")): if_true,
"//conditions:default": if_false,
})
def tfcompile_target_cpu(): def tfcompile_target_cpu():
return "" return ""