Add TPU configuration ops to default TensorFlow build
PiperOrigin-RevId: 317133514
Change-Id: I33bc6d7fdbba5915bd0d1291d4e086139c07eb14

parent d1157c976b
commit cf00e559d7
.bazelrc (4 changes)
@@ -39,6 +39,7 @@
 #
 # Feature and Third party library support options:
 # xla: Build TF with XLA
+# tpu: Build TF with TPU support
 # using_cuda: CUDA is available to build system.
 # cuda: Build with full cuda support.
 # rocm: Build with AMD GPU support (rocm).
@@ -180,6 +181,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
 # AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
 build:dbg --copt -DDEBUG_BUILD

+# Config to build TPU backend
+build:tpu --define=with_tpu_support=true
+
 build:tensorrt --action_env TF_NEED_TENSORRT=1

 build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
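Note: with this change, --config=tpu is simply shorthand for --define=with_tpu_support=true, so a TPU-enabled build can be requested with, for example, bazel build --config=tpu <target> (the target name here is illustrative, not something this commit establishes).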
@@ -467,6 +467,13 @@ config_setting(
     visibility = ["//visibility:public"],
 )

+# This flag enables experimental TPU support
+config_setting(
+    name = "with_tpu_support",
+    values = {"define": "with_tpu_support=true"},
+    visibility = ["//visibility:public"],
+)
+
 # Specifies via a config setting if this is a mobile build or not, makes
 # it easier to combine settings later.
 selects.config_setting_group(
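Note: this with_tpu_support config_setting is what the --define above toggles; BUILD rules can then branch on //tensorflow:with_tpu_support through select() (or the if_tpu helper used below) to pull in TPU-only dependencies.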
@@ -72,6 +72,7 @@ load(
     "if_ios",
     "if_mobile",
     "if_not_windows",
+    "if_tpu",
     "tf_android_core_proto_headers",
     "tf_cc_test",
     "tf_cc_test_mkl",
@@ -1093,6 +1094,8 @@ cc_library(
     ]) + if_tensorrt([
         "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
         "//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
+    ]) + if_tpu([
+        "//tensorflow/core/tpu/kernels",
     ]),
 )

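Note: the definition of if_tpu is not part of this diff; it is presumably a small macro in tensorflow.bzl that wraps select() over //tensorflow:with_tpu_support, returning its argument only when the define is set and an empty list otherwise.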
@@ -103,6 +103,7 @@ cc_library(
         ":libtftpu_header",
         "//tensorflow/c:tf_status",
     ],
+    alwayslink = True,
 )

 cc_library(
@@ -116,7 +117,20 @@ cc_library(
     deps = [
         ":libtftpu_header",
         ":tpu_config_c_api",
+        ":tpu_library_init_fns",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:status",
+        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
+        "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
+        "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
+        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
+        "//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
+        "//tensorflow/stream_executor/tpu:tpu_platform_hdrs",
     ],
 )
+
+cc_library(
+    name = "tpu_library_init_fns",
+    hdrs = ["tpu_library_init_fns.inc"],
+    visibility = ["//visibility:public"],
+)
@@ -3,6 +3,10 @@ load(
     "//tensorflow/core/platform:build_config.bzl",
     "tf_proto_library_cc",
 )
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_kernel_library",
+)

 package(
     default_visibility = [
@@ -12,6 +16,12 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )

+tf_kernel_library(
+    name = "kernels",
+    visibility = ["//visibility:public"],
+    deps = [":tpu_configuration_ops"],
+)
+
 cc_library(
     name = "tpu_compile_op_common",
     srcs = ["tpu_compile_op_common.cc"],
@@ -50,7 +60,7 @@ cc_library(
     hdrs = ["tpu_compile_op_options.h"],
 )

-cc_library(
+tf_kernel_library(
     name = "tpu_configuration_ops",
     srcs = ["tpu_configuration_ops.cc"],
     hdrs = ["tpu_configuration_ops.h"],
@@ -75,12 +85,13 @@ cc_library(
     name = "tpu_compile_c_api_hdrs",
     hdrs = ["tpu_compile_c_api.h"],
     deps = [
-        ":tpu_mesh_state_c_api",
+        ":tpu_mesh_state_c_api_hdrs",
         ":tpu_ops_common_c_api_hdrs",
         ":tpu_program_c_api_hdrs",
         "//tensorflow/c:tf_datatype",
         "//tensorflow/core/tpu:libtftpu_header",
         "//tensorflow/stream_executor/tpu:proto_helper",
     ],
     alwayslink = True,
 )

 tf_proto_library_cc(
@@ -197,8 +208,10 @@ cc_library(
 )

 cc_library(
-    name = "tpu_mesh_state_c_api",
+    name = "tpu_mesh_state_c_api_hdrs",
     hdrs = ["tpu_mesh_state_c_api.h"],
+    deps = ["//tensorflow/core/tpu:libtftpu_header"],
+    alwayslink = True,
 )

 cc_library(
@@ -207,12 +220,11 @@ cc_library(
     hdrs = ["tpu_mesh_state_interface.h"],
     deps = [
         ":tpu_compile_c_api_hdrs",
-        ":tpu_mesh_state_c_api",
+        ":tpu_mesh_state_c_api_hdrs",
         "//tensorflow/compiler/xla/service",
         "//tensorflow/core:framework",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
         "//tensorflow/core/tpu:tpu_config_c_api",
         "//tensorflow/core/tpu:tpu_library_loader",
     ],
 )

@@ -371,13 +383,16 @@ cc_library(
     name = "tpu_util_c_api_hdrs",
     hdrs = ["tpu_util_c_api.h"],
     deps = [
+        "//tensorflow/core/tpu:libtftpu_header",
         "//tensorflow/stream_executor/tpu:proto_helper",
     ],
+    alwayslink = True,
 )

 cc_library(
     name = "tpu_ops_common_c_api_hdrs",
     hdrs = ["tpu_ops_common_c_api.h"],
+    alwayslink = True,
 )

 cc_library(
@@ -387,6 +402,7 @@ cc_library(
         ":tpu_ops_common_c_api_hdrs",
         "//tensorflow/stream_executor/tpu:proto_helper",
     ],
+    alwayslink = True,
 )

 cc_library(
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
 #include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
 #include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
+#include "tensorflow/core/tpu/libtftpu.h"
 #include "tensorflow/stream_executor/tpu/proto_helper.h"

 enum TpuCoreTypeEnum {
@@ -44,35 +45,41 @@ struct CompilationCacheKeyProperty {
 extern "C" {

 // Returns the number of available TPU core count.
-int TpuTopology_AvailableCoreCount(const XLA_TpuMeshState* mesh_state,
-                                   TpuCoreTypeEnum tpu_core_type);
+TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
+    const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);

 // Creates a unique compilation cache `key` used for `put` and `get` operations.
 // Returned buffer is heap-allocated and must be owned.
-const char* TpuCompile_CreateCompilationCacheKey(
+TFTPU_CAPI_EXPORT const char* TpuCompile_CreateCompilationCacheKey(
     CompilationCacheKeyProperty property);

 // Creates a guaranteed const fingerprint. Guarantee const is normally used in
 // TPU inference to avoid re-copying unchanged variables onto the TPU device.
 // It promises the value is identical for every execution in the same session
 // even if the actual value changes in later executions.
-uint64_t TpuCompile_CreateGuaranteedConstFingerprint(uint64_t fingerprint,
-                                                     const char* data,
-                                                     size_t size);
+TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
+    uint64_t fingerprint, const char* data, size_t size);

 // Executes the computations using XLA TPU compiler and returns TPU programs
 // ready for execution.
-void TpuCompile_CompileAheadOfTime(
-    TpuSerializedProto aot_compilation_request,
-    XLA_TpuProgram** tpu_programs[],
+TFTPU_CAPI_EXPORT void TpuCompile_CompileAheadOfTime(
+    TpuSerializedProto aot_compilation_request, XLA_TpuProgram** tpu_programs[],
     size_t* count, SE_Status* status);

 // Builds `DeviceAssignment` from `TpuCompileMetadata` serialized proto.
-void TpuCompile_BuildXLADeviceAssignment(
+TFTPU_CAPI_EXPORT void TpuCompile_BuildXLADeviceAssignment(
     TpuSerializedProto serialized_tpu_compile_metadata,
     const XLA_TpuMeshState* mesh_state,
     TpuSerializedProto* serialized_device_assignment, SE_Status* status);

+struct TfTpu_CompileApiFn {
+  TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
+  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
+  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);
+  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAheadOfTime);
+  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_BuildXLADeviceAssignment);
+};
+
 }  // extern "C"

 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
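Note: libtftpu.h itself is not shown in this diff. A minimal C++ sketch of what its two macros used above are assumed to do (the real definitions may differ, e.g. with a Windows dllexport branch):

    // Assumed shape of the helpers from tensorflow/core/tpu/libtftpu.h (not part of this diff).
    // TFTPU_CAPI_EXPORT marks a C symbol as externally visible so dlsym() can resolve it by name.
    #define TFTPU_CAPI_EXPORT __attribute__((visibility("default")))

    // TFTPU_ADD_FN_IN_STRUCT(Foo) declares a member FooFn whose type is "pointer to a
    // function with Foo's exact signature"; the library loader fills it in later.
    #define TFTPU_ADD_FN_IN_STRUCT(FnName) decltype(FnName)* FnName##Fn;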
@@ -353,7 +353,7 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
     return;
   }

-  LogAndExit(42);
+  std::quick_exit(42);
 }

 /* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(
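Note: std::quick_exit terminates the process without running static destructors or std::atexit handlers (only std::at_quick_exit callbacks run), which is presumably the point of replacing LogAndExit here. A tiny, self-contained C++ illustration, unrelated to the TensorFlow code itself:

    #include <cstdio>
    #include <cstdlib>

    int main() {
      std::at_quick_exit([] { std::puts("at_quick_exit handlers still run"); });
      std::atexit([] { std::puts("atexit handlers are skipped"); });
      std::quick_exit(42);  // terminate immediately with exit status 42
    }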
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/tpu/tpu_config_c_api.h"
 #include "tensorflow/core/tpu/tpu_configuration.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/core/tpu/tpu_library_loader.h"
 #include "tensorflow/stream_executor/tpu/proto_helper.h"

 namespace tensorflow {
@@ -97,13 +98,14 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                           rmgr, tpu::kTpuMeshCommonStateResourceName));

-  ConfigureDistributedTpuOp_DoWork(
+  tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
       num_devices_per_host.size(), num_devices_per_host.data(),
       &host_config_output_size, &host_config_output, status);

-  OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
-                                   tpu::kTpuMeshCommonStateResourceName,
-                                   tpu::TpuMeshStateInterface::Create()));
+  auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
+  OP_REQUIRES_OK(ctx,
+                 rmgr->Create(rmgr->default_container(),
+                              tpu::kTpuMeshCommonStateResourceName, tpu_mesh));

   Tensor* ctx_output;
   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
@@ -112,7 +114,8 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {

   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
   TF_DeleteStatus(status);
-  TpuConfigurationApi_FreeCharArray(host_config_output);
+
+  tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);

   VLOG(1) << "ConfigureDistributedTpuOp done";
 }
@@ -171,7 +174,7 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
   core::ScopedUnref mesh_state_unref(mesh_state);

-  WaitForDistributedTpuOp_DoWork(
+  tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
       num_hosts, num_devices_per_host,
       const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
       &tpu_topology_output_size, &tpu_topology_output, status);
@@ -183,7 +186,7 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {

   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
   TF_DeleteStatus(status);
-  TpuConfigurationApi_FreeCharArray(tpu_topology_output);
+  tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);

   VLOG(1) << "WaitForDistributedTpuOp done";
 }
@@ -196,7 +199,7 @@ void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                           GetTPUConfigResourceMgr(),
                           tpu::kTpuMeshCommonStateResourceName));
-  ShutdownDistributedTpuOp_DoWork(status);
+  tpu::ConfigApiFn()->ShutdownDistributedTpuOp_DoWorkFn(status);
   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
   TF_DeleteStatus(status);

@@ -213,7 +216,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   int32_t* device_id_output;
   TF_Status* status = TF_NewStatus();

-  InitializeHostForDistributedTpuOp_DoWork(
+  tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
       tpu_host_config.size(), tpu_host_config.data(),
       enable_whole_mesh_compilations_, &device_id_output_size,
       &device_id_output, status);
@@ -230,7 +233,7 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {

   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
   TF_DeleteStatus(status);
-  TpuConfigurationApi_FreeInt32Array(device_id_output);
+  tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);

   VLOG(1) << "InitializeHostForDistributedTpuOp done";
 }
@@ -242,7 +245,8 @@ void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
   auto tpu_topology = ctx->input(0).scalar<tstring>()();
   TF_Status* status = TF_NewStatus();

-  SetGlobalTPUArrayOp_DoWork(tpu_topology.size(), tpu_topology.data(), status);
+  tpu::ConfigApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
+                                                   tpu_topology.data(), status);

   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
   TF_DeleteStatus(status);
@@ -257,7 +261,8 @@ void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
   TF_Status* status = TF_NewStatus();
   int32_t number_of_chips_output = 0;

-  DisconnectDistributedTpuChipsOp_DoWork(&number_of_chips_output, status);
+  tpu::ConfigApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
+      &number_of_chips_output, status);

   Tensor* ctx_output;
   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
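Note: the pattern applied throughout this file (and the rest of the change) replaces direct calls to C symbols with calls through a lazily created table of function pointers, so the ops no longer need a link-time dependency on the TPU library. A minimal self-contained C++ sketch of that idea, with purely illustrative names rather than the real API:

    extern "C" void ExampleOp_DoWork(int arg);  // C symbol exported by the TPU library (illustrative)

    struct TfTpu_ExampleApiFn {
      decltype(ExampleOp_DoWork)* ExampleOp_DoWorkFn = nullptr;  // filled in by the loader via dlsym
    };

    TfTpu_ExampleApiFn* ExampleApiFn() {
      static TfTpu_ExampleApiFn fns;  // one process-wide table, like ConfigApiFn() above
      return &fns;
    }

    void CallSite(int arg) {
      // Before: ExampleOp_DoWork(arg);                     -- link-time dependency on the symbol
      // After:  ExampleApiFn()->ExampleOp_DoWorkFn(arg);   -- resolved at runtime
      if (ExampleApiFn()->ExampleOp_DoWorkFn != nullptr) {
        ExampleApiFn()->ExampleOp_DoWorkFn(arg);
      }
    }

    int main() { CallSite(1); }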
@@ -15,20 +15,29 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_

+#include "tensorflow/core/tpu/libtftpu.h"
+
 typedef struct XLA_TpuMeshState XLA_TpuMeshState;

 extern "C" {

 // Creates a new TPU mesh state object.
-XLA_TpuMeshState* TpuMeshState_Create();
+TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();

 // Deletes the given TPU `mesh_state` object. Once deleted the object is
 // unusable.
-void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
+TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);

 // Returns a pointer to an opaque mesh data structure used internally.
-void* TpuMeshState_MeshCommonState(XLA_TpuMeshState* mesh_state);
+TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
+    XLA_TpuMeshState* mesh_state);

 }  // extern "C"

+struct TfTpu_MeshStateApiFn {
+  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
+  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
+};
+
 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
+#include "tensorflow/core/tpu/tpu_library_loader.h"

 namespace tensorflow {

@@ -38,19 +39,19 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {

   ~TpuMeshStateInterface() override {
     if (mesh_state_ != nullptr) {
-      TpuMeshState_Free(mesh_state_);
+      MeshStateApiFn()->TpuMeshState_FreeFn(mesh_state_);
     }
   }

   static TpuMeshStateInterface* Create() {
-    return new TpuMeshStateInterface(TpuMeshState_Create());
+    return new TpuMeshStateInterface(MeshStateApiFn()->TpuMeshState_CreateFn());
   }

   const XLA_TpuMeshState* data() const { return mesh_state_; }

   tensorflow::TpuMeshCommonState* mesh_common_state() const {
     return static_cast<tensorflow::TpuMeshCommonState*>(
-        TpuMeshState_MeshCommonState(mesh_state_));
+        MeshStateApiFn()->TpuMeshState_MeshCommonStateFn(mesh_state_));
   }

   // Returns whether we should include the device assignment as a static field
@@ -62,7 +63,7 @@ class TpuMeshStateInterface : public tensorflow::ResourceBase {
     // Static device assignment enables XLA to perform certain optimization when
     // all cores are used in the replicated computation.
     return metadata.num_cores_per_replica() * metadata.num_replicas() ==
-           TpuTopology_AvailableCoreCount(mesh_state_,
+           CompileApiFn()->TpuTopology_AvailableCoreCountFn(mesh_state_,
                                                            tpu_core_type);
   }

@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_

+#include "tensorflow/core/tpu/libtftpu.h"
 #include "tensorflow/stream_executor/tpu/proto_helper.h"

 typedef struct SE_Status SE_Status;
@@ -32,4 +33,9 @@ void TpuCompile_ToTpuShapeRepresentation(

 }  // extern "C"

+struct TfTpu_UtilApiFn {
+  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
+  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ToTpuShapeRepresentation);
+};
+
 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_C_API_H_
@@ -0,0 +1,166 @@
+namespace {
+
+tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
+  auto* config_fn = tensorflow::tpu::ConfigApiFn();
+
+  TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
+  TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
+  TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
+  TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
+  TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
+  TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
+  TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
+  TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status SetTpuMeshStateStructFns(void* library_handle) {
+  auto* mesh_state_fn = tensorflow::tpu::MeshStateApiFn();
+
+  TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Create);
+  TFTPU_SET_FN(mesh_state_fn, TpuMeshState_Free);
+  TFTPU_SET_FN(mesh_state_fn, TpuMeshState_MeshCommonState);
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status SetCompileStructFn(void* library_handle) {
+  auto* compile_fn = tensorflow::tpu::CompileApiFn();
+
+  TFTPU_SET_FN(compile_fn, TpuTopology_AvailableCoreCount);
+  TFTPU_SET_FN(compile_fn, TpuCompile_CreateCompilationCacheKey);
+  TFTPU_SET_FN(compile_fn, TpuCompile_CreateGuaranteedConstFingerprint);
+  TFTPU_SET_FN(compile_fn, TpuCompile_CompileAheadOfTime);
+  TFTPU_SET_FN(compile_fn, TpuCompile_BuildXLADeviceAssignment);
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status SetExecutorStructFn(void* library_handle) {
+  auto* executor_fn = tensorflow::tpu::ExecutorApiFn();
+
+  TFTPU_SET_FN(executor_fn, TpuPlatform_New);
+  TFTPU_SET_FN(executor_fn, TpuPlatform_Free);
+  TFTPU_SET_FN(executor_fn, TpuPlatform_Initialize);
+  TFTPU_SET_FN(executor_fn, TpuPlatform_Initialized);
+  TFTPU_SET_FN(executor_fn, TpuPlatform_GetExecutor);
+  TFTPU_SET_FN(executor_fn, TpuPlatform_Id);
+  TFTPU_SET_FN(executor_fn, TpuPlatform_VisibleDeviceCount);
+  TFTPU_SET_FN(executor_fn, TpuPlatform_TpuMemoryLimit);
+  TFTPU_SET_FN(executor_fn, TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
+
+  TFTPU_SET_FN(executor_fn, TpuExecutor_Init);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_Free);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_PlatformDeviceCount);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_Allocate);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_Deallocate);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_GetAllocatorStats);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_DeviceMemoryUsage);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateStream);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateStream);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_CreateStreamDependency);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_GetStatus);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateEvent);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateEvent);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_PollForEventStatus);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_RecordEvent);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForEvent);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateTimer);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateTimer);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_StartTimer);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_StopTimer);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronousMemcpyToHost);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronousMemcpyFromHost);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_MemcpyToHost);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_MemcpyFromHost);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_EnqueueInfeed);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_DequeueOutfeed);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForInfeedReady);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_WaitForOutfeedReady);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_BlockHostUntilDone);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_BlockUntilDoneOrFailed);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_SyncAndForgetFailedStreams);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_SynchronizeAllActivity);
+
+  TFTPU_SET_FN(executor_fn, TpuStream_New);
+  TFTPU_SET_FN(executor_fn, TpuStream_Free);
+  TFTPU_SET_FN(executor_fn, TpuStream_Stream);
+  TFTPU_SET_FN(executor_fn, TpuStream_Status);
+  TFTPU_SET_FN(executor_fn, TpuStream_IsSameSharedMemoryLocation);
+  TFTPU_SET_FN(executor_fn, TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
+
+  TFTPU_SET_FN(executor_fn, TpuEvent_New);
+  TFTPU_SET_FN(executor_fn, TpuEvent_Free);
+
+  TFTPU_SET_FN(executor_fn, TpuTimer_New);
+  TFTPU_SET_FN(executor_fn, TpuTimer_Free);
+  TFTPU_SET_FN(executor_fn, TpuTimer_Nanoseconds);
+  TFTPU_SET_FN(executor_fn, TpuTimer_Microseconds);
+
+  TFTPU_SET_FN(executor_fn, TpuStatus_New);
+  TFTPU_SET_FN(executor_fn, TpuStatus_Create);
+  TFTPU_SET_FN(executor_fn, TpuStatus_Free);
+  TFTPU_SET_FN(executor_fn, TpuStatus_Message);
+  TFTPU_SET_FN(executor_fn, TpuStatus_Code);
+  TFTPU_SET_FN(executor_fn, TpuStatus_Ok);
+
+  TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Default);
+  TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_SetOrdinal);
+  TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Free);
+
+  TFTPU_SET_FN(executor_fn, TpuDeviceDescription_New);
+  TFTPU_SET_FN(executor_fn, TpuDeviceDescription_Free);
+
+  TFTPU_SET_FN(executor_fn, TpuExecutor_CreateDeviceDescription);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_NewDeviceOptions);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_FreeDeviceOptions);
+  TFTPU_SET_FN(executor_fn, TpuExecutor_HostCallback);
+
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_New);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_Free);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_PlatformId);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_HostShapeToDeviceShape);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralToDeviceAsync);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralFromDevice);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
+
+  TFTPU_SET_FN(executor_fn, TpuComputationPlacer_New);
+  TFTPU_SET_FN(executor_fn, TpuComputationPlacer_Free);
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status SetTpuNodeContextStructFns(void* library_handle) {
+  auto* node_context_fn = tensorflow::tpu::NodeContextApiFn();
+
+  TFTPU_SET_FN(node_context_fn, TpuNodeContext_Create);
+  TFTPU_SET_FN(node_context_fn, TpuNodeContext_Free);
+  TFTPU_SET_FN(node_context_fn, TpuNodeContext_StopChipHeartbeats);
+  TFTPU_SET_FN(node_context_fn, TpuNodeContext_CloseTpuHost);
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status SetTpuUtilStructFns(void* library_handle) {
+  auto* util_fn = tensorflow::tpu::UtilApiFn();
+
+  TFTPU_SET_FN(util_fn, TpuCompile_IsTpuCompilationEnabled);
+  TFTPU_SET_FN(util_fn, TpuCompile_ToTpuShapeRepresentation);
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status InitializeTpuStructFns(void* library_handle) {
+  TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
+  TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle));
+  TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle));
+  TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
+  TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle));
+  TF_RETURN_IF_ERROR(SetTpuUtilStructFns(library_handle));
+
+  return tensorflow::Status::OK();
+}
+
+}  // namespace
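Note: a minimal, self-contained C++ illustration of what TFTPU_SET_FN (defined in tpu_library_loader.cc below) does for each entry in this new .inc file: resolve the symbol by name with dlsym and store it in the matching *Fn field. The names here are illustrative, not part of the real TPU API:

    // Build with -ldl on Linux.
    #include <dlfcn.h>
    #include <cstdio>

    extern "C" int TpuExample_Query();  // illustrative C symbol; the real tables hold TPU API functions

    struct TfTpu_ExampleApiFn {
      decltype(TpuExample_Query)* TpuExample_QueryFn = nullptr;
    };

    bool PopulateExampleApi(void* library_handle, TfTpu_ExampleApiFn* fns) {
      // Roughly what one TFTPU_SET_FN(fns, TpuExample_Query) expansion performs.
      fns->TpuExample_QueryFn = reinterpret_cast<decltype(TpuExample_Query)*>(
          dlsym(library_handle, "TpuExample_Query"));
      if (fns->TpuExample_QueryFn == nullptr) {
        std::fprintf(stderr, "TpuExample_Query not available in this library\n");
        return false;
      }
      return true;
    }

    int main() {
      TfTpu_ExampleApiFn fns;
      // dlopen(nullptr, ...) hands back a handle to the current process, mirroring
      // the fallback path in InitializeTpuLibrary below.
      PopulateExampleApi(dlopen(nullptr, RTLD_NOW), &fns);
    }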
@@ -13,16 +13,23 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+// TODO(frankchn): Rename to `tpu_api_dlsym_initializer` or similar.

 #include "tensorflow/core/tpu/tpu_library_loader.h"

 #include <dlfcn.h>

-#define TFTPU_SET_FN(Struct, FnName)                                         \
-  Struct->FnName##Fn =                                                       \
-      reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName));
-
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"

+#define TFTPU_SET_FN(Struct, FnName)                                         \
+  Struct->FnName##Fn =                                                       \
+      reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName));   \
+  if (!(Struct->FnName##Fn)) {                                               \
+    LOG(ERROR) << #FnName " not available in this library.";                 \
+  }

 // Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly
 // visible methods.
@@ -30,28 +37,7 @@
 namespace tensorflow {
 namespace tpu {

-Status SetTpuInitializeStructFns(void* library_handle) {
-  auto* base_fn = InitializeApiFn();
-
-  TFTPU_SET_FN(base_fn, TfTpu_Initialize);
-
-  return Status::OK();
-}
-
-Status SetTpuConfigStructFns(void* library_handle) {
-  auto* config_fn = ConfigApiFn();
-
-  TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
-  TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
-  TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
-  TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
-  TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
-  TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
-  TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
-  TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
-
-  return Status::OK();
-}
+#include "tensorflow/core/tpu/tpu_library_init_fns.inc"

 TfTpu_BaseFn* InitializeApiFn() {
   static TfTpu_BaseFn base_fn;
@@ -63,19 +49,48 @@ TfTpu_ConfigApiFn* ConfigApiFn() {
   return &config_api_fn;
 }

+TfTpu_MeshStateApiFn* MeshStateApiFn() {
+  static TfTpu_MeshStateApiFn mesh_state_api_fn;
+  return &mesh_state_api_fn;
+}
+
+TfTpu_CompileApiFn* CompileApiFn() {
+  static TfTpu_CompileApiFn compile_api_fn;
+  return &compile_api_fn;
+}
+
+TfTpu_ExecutorApiFn* ExecutorApiFn() {
+  static TfTpu_ExecutorApiFn executor_api_fn;
+  return &executor_api_fn;
+}
+
+TfTpu_NodeContextApiFn* NodeContextApiFn() {
+  static TfTpu_NodeContextApiFn node_context_api_fn;
+  return &node_context_api_fn;
+}
+
+TfTpu_UtilApiFn* UtilApiFn() {
+  static TfTpu_UtilApiFn util_api_fn;
+  return &util_api_fn;
+}
+
 Status InitializeTpuLibrary(void* library_handle) {
   bool shared_object_loaded = true;
   if (library_handle == nullptr) {
-    library_handle = dlopen(nullptr, RTLD_LAZY);
+    library_handle = dlopen(nullptr, RTLD_NOW);
     shared_object_loaded = false;
   }

   TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle));
-  TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
+  TF_RETURN_IF_ERROR(InitializeTpuStructFns(library_handle));

   if (shared_object_loaded) {
+    // TODO(frankchn): Make initialization actually work
     // Initialize TPU platform when the platform code is loaded from a library.
-    InitializeApiFn()->TfTpu_InitializeFn();
+    // InitializeApiFn()->TfTpu_InitializeFn();

     // We should only register the TPU platform when the library is loaded.
     // TODO(frankchn): Resolve the circular dependency and register the platform
     // RegisterTpuPlatform();
   }

   return Status::OK();
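Note: a sketch of how the loader added here might be driven by client code; the shared-object name below is an assumption, not something this change establishes:

    #include <dlfcn.h>
    #include "tensorflow/core/tpu/tpu_library_loader.h"

    tensorflow::Status LoadAndInitTpu() {
      // Hypothetical library name. Passing nullptr instead makes InitializeTpuLibrary
      // fall back to dlopen(nullptr, RTLD_NOW), i.e. resolving the TPU symbols from the
      // current process (the statically linked case).
      void* handle = dlopen("libtftpu.so", RTLD_NOW | RTLD_LOCAL);
      return tensorflow::tpu::InitializeTpuLibrary(handle);
    }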
@@ -17,8 +17,13 @@ limitations under the License.
 #define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_

 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
 #include "tensorflow/core/tpu/libtftpu.h"
 #include "tensorflow/core/tpu/tpu_config_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"

 // LINT.IfChange
 namespace tensorflow {
@@ -26,10 +31,21 @@ namespace tpu {

 Status InitializeTpuLibrary(void* library_handle);

+// TODO(frankchn): Separate out API functions from the loader.
 TfTpu_BaseFn* InitializeApiFn();

 TfTpu_ConfigApiFn* ConfigApiFn();

+TfTpu_MeshStateApiFn* MeshStateApiFn();
+
+TfTpu_CompileApiFn* CompileApiFn();
+
+TfTpu_ExecutorApiFn* ExecutorApiFn();
+
+TfTpu_NodeContextApiFn* NodeContextApiFn();
+
+TfTpu_UtilApiFn* UtilApiFn();
+
 }  // namespace tpu
 }  // namespace tensorflow
 // LINT.ThenChange(//tensorflow/core/tpu/tpu_library_loader_windows.cc)
@@ -27,6 +27,16 @@ TfTpu_BaseFn* InitializeApiFn() { return nullptr; }

 TfTpu_ConfigApiFn* ConfigApiFn() { return nullptr; }

+TfTpu_MeshStateApiFn* MeshStateApiFn() { return nullptr; }
+
+TfTpu_CompileApiFn* CompileApiFn() { return nullptr; }
+
+TfTpu_ExecutorApiFn* ExecutorApiFn() { return nullptr; }
+
+TfTpu_NodeContextApiFn* NodeContextApiFn() { return nullptr; }
+
+TfTpu_UtilApiFn* UtilApiFn() { return nullptr; }
+
 Status InitializeTpuLibrary(void* library_handle) {
   return errors::Unimplemented(
       "Loading TPU library is not supported on Windows.");
@@ -11,20 +11,25 @@ package(
 cc_library(
     name = "tpu_executor_c_api_hdrs",
     hdrs = ["tpu_executor_c_api.h"],
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/c:tf_attrtype",
         "//tensorflow/c:tf_datatype",
         "//tensorflow/c:tf_status",
         "//tensorflow/core/tpu:libtftpu_header",
         "//tensorflow/core/tpu/kernels:tpu_ops_common_c_api_hdrs",
     ],
     alwayslink = True,
 )

 cc_library(
     name = "tpu_node_context_c_api_hdrs",
     hdrs = ["tpu_node_context_c_api.h"],
     visibility = ["//visibility:public"],
     deps = [
         ":tpu_executor_c_api_hdrs",
         "//tensorflow/core/tpu:libtftpu_header",
     ],
     alwayslink = True,
 )

 cc_library(
@@ -65,6 +70,7 @@ cc_library(
         ":status_helper",
         ":tpu_executor_c_api_hdrs",
         ":tpu_stream_interface",
+        "//tensorflow/core/tpu:tpu_library_loader",
         "//tensorflow/stream_executor:stream",
     ],
 )
@@ -75,6 +81,7 @@ cc_library(
     deps = [
         ":tpu_executor_c_api_hdrs",
         "//tensorflow/core/platform:types",
+        "//tensorflow/core/tpu:tpu_library_loader",
         "//tensorflow/stream_executor:stream",
     ],
 )
@@ -94,6 +101,7 @@ cc_library(
         ":tpu_timer",
         "//tensorflow/c:tf_status",
         "//tensorflow/core:lib",
+        "//tensorflow/core/tpu:tpu_library_loader",
        "//tensorflow/stream_executor:stream",
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -143,6 +151,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:stream_pool",
         "//tensorflow/compiler/xla/service:transfer_manager",
         "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_library_loader",
         "//tensorflow/stream_executor:device_memory_allocator",
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/memory",
@@ -160,6 +169,7 @@ cc_library(
         ":tpu_platform_interface",
         "//tensorflow/c:tf_status",
         "//tensorflow/core/platform:types",
+        "//tensorflow/core/tpu:tpu_library_loader",
         "//tensorflow/stream_executor:stream",
         "@com_google_absl//absl/container:flat_hash_map",
     ],
@@ -191,6 +201,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/compiler/xla/service:transfer_manager",
+        "//tensorflow/core/tpu:tpu_library_loader",
         "//tensorflow/stream_executor:stream",
     ],
 )
@@ -217,6 +228,7 @@ cc_library(
         "//tensorflow/core/platform:types",
         "//tensorflow/stream_executor:multi_platform_manager",
         "//tensorflow/stream_executor:stream_executor_headers",
         "//tensorflow/stream_executor:stream_executor_pimpl",
     ],
 )

@@ -17,6 +17,7 @@ limitations under the License.

 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/tpu/tpu_library_loader.h"
 #include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
@@ -33,63 +34,68 @@ namespace {
 using ::stream_executor::port::Status;
 }  // namespace

-TpuExecutor::~TpuExecutor() { TpuExecutor_Free(executor_); }
+TpuExecutor::~TpuExecutor() {
+  tpu::ExecutorApiFn()->TpuExecutor_FreeFn(executor_);
+}

 Status TpuExecutor::Init(int device_ordinal,
                          ::stream_executor::DeviceOptions device_options) {
   StatusHelper status;
   SE_DeviceOptions* options =
-      TpuExecutor_NewDeviceOptions(device_options.flags());
-  TpuExecutor_Init(executor_, device_ordinal, options, status.c_status);
-  TpuExecutor_FreeDeviceOptions(options);
+      tpu::ExecutorApiFn()->TpuExecutor_NewDeviceOptionsFn(
+          device_options.flags());
+  tpu::ExecutorApiFn()->TpuExecutor_InitFn(executor_, device_ordinal, options,
+                                           status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_FreeDeviceOptionsFn(options);
   return status.status();
 }

 int TpuExecutor::PlatformDeviceCount() {
-  return TpuExecutor_PlatformDeviceCount(executor_);
+  return tpu::ExecutorApiFn()->TpuExecutor_PlatformDeviceCountFn(executor_);
 }

 void TpuExecutor::SyncAndForgetFailedStreams() {
-  TpuExecutor_SyncAndForgetFailedStreams(executor_);
+  tpu::ExecutorApiFn()->TpuExecutor_SyncAndForgetFailedStreamsFn(executor_);
 }

 bool TpuExecutor::SynchronizeAllActivity() {
-  return TpuExecutor_SynchronizeAllActivity(executor_);
+  return tpu::ExecutorApiFn()->TpuExecutor_SynchronizeAllActivityFn(executor_);
 }

 Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
   StatusHelper status;
-  TpuExecutor_BlockHostUntilDone(
+  tpu::ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn(
       executor_, stream_map().at(stream->implementation()), status.c_status);
   return status.status();
 }

 Status TpuExecutor::BlockUntilDoneOrFailed() {
   StatusHelper status;
-  TpuExecutor_BlockUntilDoneOrFailed(executor_, status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_BlockUntilDoneOrFailedFn(executor_,
+                                                             status.c_status);
   return status.status();
 }

 Status TpuExecutor::GetStatus(Stream* stream) {
   StatusHelper status;
-  TpuExecutor_GetStatus(executor_, stream_map().at(stream->implementation()),
-                        status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_GetStatusFn(
+      executor_, stream_map().at(stream->implementation()), status.c_status);
   return status.status();
 }

 bool TpuExecutor::AllocateStream(Stream* stream) {
-  return TpuExecutor_AllocateStream(executor_,
-                                    stream_map().at(stream->implementation()));
+  return tpu::ExecutorApiFn()->TpuExecutor_AllocateStreamFn(
+      executor_, stream_map().at(stream->implementation()));
 }

 void TpuExecutor::DeallocateStream(Stream* stream) {
-  TpuExecutor_DeallocateStream(executor_,
-                               stream_map().at(stream->implementation()));
+  tpu::ExecutorApiFn()->TpuExecutor_DeallocateStreamFn(
+      executor_, stream_map().at(stream->implementation()));
   stream_map().erase(stream->implementation());
 }

 bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
-  return TpuExecutor_CreateStreamDependency(
+  return tpu::ExecutorApiFn()->TpuExecutor_CreateStreamDependencyFn(
       executor_, stream_map().at(dependent->implementation()),
       stream_map().at(other->implementation()));
 }
@@ -104,14 +110,14 @@ bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
 void TpuExecutor::DeallocateTimer(Timer* timer) {}

 bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
-  return TpuExecutor_StartTimer(executor_,
-                                stream_map().at(stream->implementation()),
+  return tpu::ExecutorApiFn()->TpuExecutor_StartTimerFn(
+      executor_, stream_map().at(stream->implementation()),
       timer_map_.at(timer->implementation()));
 }

 bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
-  return TpuExecutor_StopTimer(executor_,
-                               stream_map().at(stream->implementation()),
+  return tpu::ExecutorApiFn()->TpuExecutor_StopTimerFn(
+      executor_, stream_map().at(stream->implementation()),
       timer_map_.at(timer->implementation()));
 }

@@ -148,7 +154,7 @@ Status TpuExecutor::WaitForEvent(Stream* stream,
 // Called by Timer::Timer
 std::unique_ptr<::stream_executor::internal::TimerInterface>
 TpuExecutor::GetTimerImplementation() {
-  SE_Timer* tpu_timer = TpuTimer_New(executor_);
+  SE_Timer* tpu_timer = tpu::ExecutorApiFn()->TpuTimer_NewFn(executor_);
   auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
   timer_map_[ptr.get()] = tpu_timer;
   return ptr;
@@ -157,7 +163,7 @@ TpuExecutor::GetTimerImplementation() {
 // Called by Stream::Stream
 std::unique_ptr<::stream_executor::internal::StreamInterface>
 TpuExecutor::GetStreamImplementation() {
-  SE_Stream* tpu_stream = TpuStream_New(executor_);
+  SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
   auto ptr = absl::make_unique<TpuStream>(tpu_stream);
   stream_map()[ptr.get()] = tpu_stream;
   return ptr;
@@ -166,34 +172,35 @@ TpuExecutor::GetStreamImplementation() {
 // Called by Event::Event
 std::unique_ptr<::stream_executor::internal::EventInterface>
 TpuExecutor::CreateEventImplementation() {
-  SE_Event* tpu_event = TpuEvent_New(executor_);
+  SE_Event* tpu_event = tpu::ExecutorApiFn()->TpuEvent_NewFn(executor_);
   auto ptr = absl::make_unique<TpuEvent>(tpu_event);
   event_map()[ptr.get()] = tpu_event;
   return ptr;
 }

 DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
-  SE_DeviceMemoryBase se_base =
-      TpuExecutor_Allocate(executor_, size, memory_space);
+  SE_DeviceMemoryBase se_base = tpu::ExecutorApiFn()->TpuExecutor_AllocateFn(
+      executor_, size, memory_space);
   return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
 }

 void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
   SE_DeviceMemoryBase se_base =
       TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
-  TpuExecutor_Deallocate(executor_, &se_base);
+  tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
 }

 void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
   SE_DeviceMemoryBase se_base =
       TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
-  TpuExecutor_Deallocate(executor_, &se_base);
+  tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
 }

 bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
   int64_t _free;
   int64_t _total;
-  if (TpuExecutor_DeviceMemoryUsage(executor_, &_free, &_total)) {
+  if (tpu::ExecutorApiFn()->TpuExecutor_DeviceMemoryUsageFn(executor_, &_free,
+                                                            &_total)) {
     *free = _free;
     *total = _total;
     return true;
@@ -204,7 +211,8 @@ bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
 absl::optional<stream_executor::AllocatorStats>
 TpuExecutor::GetAllocatorStats() {
   SE_AllocatorStats c_stats;
-  if (TpuExecutor_GetAllocatorStats(executor_, &c_stats)) {
+  if (tpu::ExecutorApiFn()->TpuExecutor_GetAllocatorStatsFn(executor_,
+                                                            &c_stats)) {
     ::stream_executor::AllocatorStats stats;
     stats.num_allocs = c_stats.num_allocs;
     stats.bytes_in_use = c_stats.bytes_in_use;
@@ -226,31 +234,33 @@ TpuExecutor::GetAllocatorStats() {

 Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
   StatusHelper status;
-  TpuExecutor_WaitForInfeedReady(executor_, infeed_queue_index,
-                                 status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_WaitForInfeedReadyFn(
+      executor_, infeed_queue_index, status.c_status);
   return status.status();
 }

 Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
   StatusHelper status;
-  TpuExecutor_WaitForOutfeedReady(executor_, outfeed_queue_index,
-                                  status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_WaitForOutfeedReadyFn(
+      executor_, outfeed_queue_index, status.c_status);
   return status.status();
 }

 void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
                                  absl::Span<uint8> bytes, StatusCallback done) {
   StatusHelper status;
-  TpuExecutor_DequeueOutfeed(executor_, outfeed_queue_index, bytes.data(),
-                             bytes.size(), status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_DequeueOutfeedFn(
+      executor_, outfeed_queue_index, bytes.data(), bytes.size(),
+      status.c_status);
   done(status.status());
 }

 Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
                                   absl::Span<const uint8> bytes) {
   StatusHelper status;
-  TpuExecutor_EnqueueInfeed(executor_, infeed_queue_index, bytes.data(),
-                            bytes.size(), status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_EnqueueInfeedFn(
+      executor_, infeed_queue_index, bytes.data(), bytes.size(),
+      status.c_status);
   return status.status();
 }

@@ -259,9 +269,9 @@ bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
                          uint64 size) {
   SE_DeviceMemoryBase se_base =
       TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
-  return TpuExecutor_MemcpyToHost(executor_,
-                                  stream_map().at(stream->implementation()),
-                                  host_dst, &se_base, size);
+  return tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
+      executor_, stream_map().at(stream->implementation()), host_dst, &se_base,
+      size);
 }

 bool TpuExecutor::Memcpy(Stream* stream,
@@ -269,9 +279,9 @@ bool TpuExecutor::Memcpy(Stream* stream,
                          const void* host_src, uint64 size) {
   SE_DeviceMemoryBase se_base =
       TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
-  return TpuExecutor_MemcpyFromHost(executor_,
-                                    stream_map().at(stream->implementation()),
-                                    &se_base, host_src, size);
+  return tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
+      executor_, stream_map().at(stream->implementation()), &se_base, host_src,
+      size);
 }

 Status TpuExecutor::SynchronousMemcpy(
@@ -280,8 +290,8 @@ Status TpuExecutor::SynchronousMemcpy(
   StatusHelper status;
   SE_DeviceMemoryBase se_base =
       TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
-  TpuExecutor_SynchronousMemcpyFromHost(executor_, &se_base, host_src, size,
-                                        status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyFromHostFn(
+      executor_, &se_base, host_src, size, status.c_status);
   return status.status();
 }

@@ -291,8 +301,8 @@ Status TpuExecutor::SynchronousMemcpy(
   StatusHelper status;
   SE_DeviceMemoryBase se_base =
       TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
-  TpuExecutor_SynchronousMemcpyToHost(executor_, host_dst, &se_base, size,
-                                      status.c_status);
+  tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyToHostFn(
+      executor_, host_dst, &se_base, size, status.c_status);
   return status.status();
 }

@@ -316,8 +326,8 @@ struct HostCallbackContext {
 SE_Status* HostCallbackTrampoline(void* ctx) {
   HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
   Status status = host_ctx->callback();
-  SE_Status* c_status =
-      TpuStatus_Create(status.code(), status.error_message().c_str());
+  SE_Status* c_status = tpu::ExecutorApiFn()->TpuStatus_CreateFn(
+      status.code(), status.error_message().c_str());
   delete host_ctx;
   return c_status;
 }
@@ -325,18 +335,21 @@ SE_Status* HostCallbackTrampoline(void* ctx) {
 bool TpuExecutor::HostCallback(Stream* stream,
                                std::function<Status()> callback) {
   HostCallbackContext* ctx = new HostCallbackContext{callback};
-  return TpuExecutor_HostCallback(executor_,
-                                  stream_map().at(stream->implementation()),
+  return tpu::ExecutorApiFn()->TpuExecutor_HostCallbackFn(
+      executor_, stream_map().at(stream->implementation()),
       &HostCallbackTrampoline, ctx);
 }

 TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
 TpuExecutor::CreateDeviceDescription() const {
   StatusHelper status;
-  SE_DeviceDescription* description = TpuDeviceDescription_New();
-  auto cleanup = tensorflow::gtl::MakeCleanup(
-      [description]() { TpuDeviceDescription_Free(description); });
-  TpuExecutor_CreateDeviceDescription(executor_, description, status.c_status);
+  SE_DeviceDescription* description =
+      tpu::ExecutorApiFn()->TpuDeviceDescription_NewFn();
+  auto cleanup = tensorflow::gtl::MakeCleanup([description]() {
+    tpu::ExecutorApiFn()->TpuDeviceDescription_FreeFn(description);
+  });
+  tpu::ExecutorApiFn()->TpuExecutor_CreateDeviceDescriptionFn(
+      executor_, description, status.c_status);
   if (status.status().ok()) {
     stream_executor::internal::DeviceDescriptionBuilder builder;
     CHECK_NE(description->device_vendor, nullptr);
@@ -20,9 +20,9 @@ limitations under the License.
 #include <stdint.h>

 #include "tensorflow/c/tf_attrtype.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
 #include "tensorflow/core/tpu/libtftpu.h"

 typedef struct SE_Platform SE_Platform;
 typedef struct SE_StreamExecutor SE_StreamExecutor;
@@ -292,6 +292,96 @@ void TpuTransferManager_WriteSingleTupleIndexTable(

 XLA_ComputationPlacer* TpuComputationPlacer_New();
 void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);

+struct TfTpu_ExecutorApiFn {
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_New);
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Initialize);
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Initialized);
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetExecutor);
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_Id);
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_VisibleDeviceCount);
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_TpuMemoryLimit);
+  TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Init);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_PlatformDeviceCount);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Allocate);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Deallocate);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_GetAllocatorStats);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeviceMemoryUsage);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateStream);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateStream);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_CreateStreamDependency);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_GetStatus);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateEvent);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateEvent);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_PollForEventStatus);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_RecordEvent);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForEvent);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateTimer);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateTimer);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_StartTimer);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_StopTimer);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronousMemcpyToHost);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronousMemcpyFromHost);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_MemcpyToHost);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_MemcpyFromHost);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_EnqueueInfeed);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DequeueOutfeed);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForInfeedReady);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_WaitForOutfeedReady);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_BlockHostUntilDone);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_BlockUntilDoneOrFailed);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SyncAndForgetFailedStreams);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_SynchronizeAllActivity);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuStream_New);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStream_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStream_Stream);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStream_Status);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStream_IsSameSharedMemoryLocation);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStream_TpuEnqueueOnDeviceSendRecvLocal);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuEvent_New);
+  TFTPU_ADD_FN_IN_STRUCT(TpuEvent_Free);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuTimer_New);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Nanoseconds);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTimer_Microseconds);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuStatus_New);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Create);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Message);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Code);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Ok);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Default);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_SetOrdinal);
+  TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Free);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_New);
+  TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_Free);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_CreateDeviceDescription);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_NewDeviceOptions);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_FreeDeviceOptions);
+  TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_HostCallback);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_New);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_PlatformId);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_HostShapeToDeviceShape);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralToDeviceAsync);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralFromDevice);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
+
+  TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_New);
+  TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_Free);
+};
 }

 // extern "C"
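Note: TfTpu_ExecutorApiFn mirrors, one TFTPU_ADD_FN_IN_STRUCT entry per declaration, every C function in this header. Any function added to this header presumably also has to be registered in SetExecutorStructFn in tpu_library_init_fns.inc above; otherwise its corresponding *Fn pointer in the table would simply be left unset at runtime.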
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/core/tpu/tpu_library_loader.h"
 #include "tensorflow/stream_executor/device_memory_allocator.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 #include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
@@ -32,15 +33,18 @@ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
     int device_ordinal) {
   StatusHelper status;
   XLA_TpuNodeContext* node_context =
-      TpuNodeContext_Create(device_ordinal, status.c_status);
+      tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(device_ordinal,
+                                                       status.c_status);
   if (!status.status().ok()) {
-    TpuNodeContext_Free(node_context);
+    tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context);
     return status.status();
   }
   return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
 }

-TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }
+TpuNodeContext::~TpuNodeContext() {
+  tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context_);
+}

 /* static */
 Status TpuNodeContext::Initialize(int device_ordinal) {
@@ -52,14 +56,14 @@ Status TpuNodeContext::Initialize(int device_ordinal) {
 /* static */
 Status TpuNodeContext::StopChipHeartbeats() {
   StatusHelper status;
-  TpuNodeContext_StopChipHeartbeats(status.c_status);
+  tpu::NodeContextApiFn()->TpuNodeContext_StopChipHeartbeatsFn(status.c_status);
   return status.status();
 }

 /* static */
 Status TpuNodeContext::CloseTpuHost() {
   StatusHelper status;
-  TpuNodeContext_CloseTpuHost(status.c_status);
+  tpu::NodeContextApiFn()->TpuNodeContext_CloseTpuHostFn(status.c_status);
   return status.status();
 }

@@ -15,10 +15,13 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
 
+#include "tensorflow/core/tpu/libtftpu.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 
 typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
 
 extern "C" {
 
 XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
                                           SE_Status* status);
 void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);

@@ -28,4 +31,13 @@ void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
 void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
 void TpuNodeContext_CloseTpuHost(SE_Status* status);
 
 }  // extern "C"
 
+struct TfTpu_NodeContextApiFn {
+  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
+  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
+  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
+};
+
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
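The TfTpu_NodeContextApiFn struct added above mirrors the C entry points declared earlier in the same header. The tpu_library_init_fns.inc target introduced in this change presumably fills these pointers in when the TPU library is initialized, but that file is not shown in this diff, so the code below is only an assumption-laden sketch of how such a table could be populated; the struct name, the LoadNodeContextSymbols function, and the dlsym-based lookup are hypothetical.

// Rough illustration only: the symbol names come from the header above, but the
// struct name, loader function, and dlsym-based lookup are assumptions.
#include <dlfcn.h>

#include <stdexcept>
#include <string>

struct SE_Status;  // Opaque status type from the C API (declaration only).

// Mirrors the shape of TfTpu_NodeContextApiFn: one pointer per entry point.
struct NodeContextApiSketch {
  void (*TpuNodeContext_StopChipHeartbeatsFn)(SE_Status*) = nullptr;
  void (*TpuNodeContext_CloseTpuHostFn)(SE_Status*) = nullptr;
};

// Hypothetical loader: resolve each symbol from an already dlopen()ed handle
// and store it in the table that call sites consult.
inline void LoadNodeContextSymbols(void* library_handle,
                                   NodeContextApiSketch* api) {
  auto resolve = [&](const char* name) -> void* {
    void* sym = dlsym(library_handle, name);
    if (sym == nullptr) {
      throw std::runtime_error(std::string("missing symbol: ") + name);
    }
    return sym;
  };
  api->TpuNodeContext_StopChipHeartbeatsFn =
      reinterpret_cast<void (*)(SE_Status*)>(
          resolve("TpuNodeContext_StopChipHeartbeats"));
  api->TpuNodeContext_CloseTpuHostFn = reinterpret_cast<void (*)(SE_Status*)>(
      resolve("TpuNodeContext_CloseTpuHost"));
}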
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
 
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/tpu/tpu_library_loader.h"
 #include "tensorflow/stream_executor/platform.h"
 #include "tensorflow/stream_executor/tpu/status_helper.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor.h"
@@ -30,7 +31,9 @@ using Status = ::stream_executor::port::Status;
 template <typename T>
 using StatusOr = ::stream_executor::port::StatusOr<T>;
 
-TpuPlatform::TpuPlatform() { platform_ = TpuPlatform_New(); }
+TpuPlatform::TpuPlatform() {
+  platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
+}
 
 TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
   return tpu_registered_platform;
@@ -53,8 +56,8 @@ Status TpuPlatform::Initialize(
     i++;
   }
 
-  TpuPlatform_Initialize(platform_, options_size, options_key, options_value,
-                         status.c_status);
+  tpu::ExecutorApiFn()->TpuPlatform_InitializeFn(
+      platform_, options_size, options_key, options_value, status.c_status);
 
   free(options_key);
   free(options_value);
@@ -62,10 +65,16 @@ Status TpuPlatform::Initialize(
   return status.status();
 }
 
-TpuPlatform::~TpuPlatform() { TpuPlatform_Free(platform_); }
+bool TpuPlatform::Initialized() const {
+  return tpu::ExecutorApiFn()->TpuPlatform_InitializedFn(platform_);
+}
+
+TpuPlatform::~TpuPlatform() {
+  tpu::ExecutorApiFn()->TpuPlatform_FreeFn(platform_);
+}
 
 int TpuPlatform::VisibleDeviceCount() const {
-  return TpuPlatform_VisibleDeviceCount(platform_);
+  return tpu::ExecutorApiFn()->TpuPlatform_VisibleDeviceCountFn(platform_);
 }
 
 StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
@@ -77,14 +86,16 @@ StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
 StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
 TpuPlatform::GetUncachedExecutor(
     const ::stream_executor::StreamExecutorConfig& config) {
-  SE_StreamExecutorConfig* c_config = TpuStreamExecutorConfig_Default();
+  SE_StreamExecutorConfig* c_config =
+      tpu::ExecutorApiFn()->TpuStreamExecutorConfig_DefaultFn();
 
-  TpuStreamExecutorConfig_SetOrdinal(c_config, config.ordinal);
+  tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn(c_config,
+                                                             config.ordinal);
 
   StatusHelper status;
-  SE_StreamExecutor* executor =
-      TpuPlatform_GetExecutor(platform_, c_config, status.c_status);
-  TpuStreamExecutorConfig_Free(c_config);
+  SE_StreamExecutor* executor = tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn(
+      platform_, c_config, status.c_status);
+  tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn(c_config);
   if (!status.ok()) {
     return status.status();
   }
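GetUncachedExecutor above creates a C config object, sets its ordinal, passes it to the platform, and then frees it by hand. As a side note, a small scope guard is a common way to tie that kind of cleanup to scope exit; the sketch below is generic C++ for illustration (it uses C++17 class template argument deduction) and is not one of TensorFlow's helpers.

// Generic scope-guard sketch; not part of the TensorFlow code in this diff.
#include <cstdio>
#include <utility>

template <typename F>
class Cleanup {
 public:
  explicit Cleanup(F f) : f_(std::move(f)) {}
  ~Cleanup() { f_(); }
  Cleanup(const Cleanup&) = delete;
  Cleanup& operator=(const Cleanup&) = delete;

 private:
  F f_;
};

int main() {
  std::puts("create config");
  // Runs its callback at scope exit, mirroring the create / set ordinal /
  // get executor / free sequence in GetUncachedExecutor.
  Cleanup free_config([] { std::puts("free config"); });
  std::puts("set ordinal and get executor");
  return 0;  // "free config" prints here, including on early returns.
}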
@@ -103,27 +114,24 @@ const std::string& TpuPlatform::Name() const {
 }
 
 int64 TpuPlatform::TpuMemoryLimit() {
-  return TpuPlatform_TpuMemoryLimit(platform_);
+  return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
 }
 
 bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
-  return TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(platform_);
+  return tpu::ExecutorApiFn()
+      ->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
 }
 
 }  // namespace tensorflow
 
 void RegisterTpuPlatform() {
   static bool tpu_platform_registered = false;
   if (!tpu_platform_registered) {
     tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
     std::unique_ptr<stream_executor::Platform> platform(
         tensorflow::tpu_registered_platform);
     SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
         std::move(platform)));
     tpu_platform_registered = true;
   }
 }
 
 REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());
 
 // Note that module initialization sequencing is not supported in the
 // open-source project, so this will be a no-op there.
 REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
 REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
                                      tpu_platform);
 }  // namespace tensorflow
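RegisterTpuPlatform() is run through REGISTER_MODULE_INITIALIZER, so the TPU platform registers itself with MultiPlatformManager during static initialization, with the caveat from the comment above that initializer sequencing is a no-op in the open-source build. Below is a minimal, framework-free sketch of the same guarded self-registration idea; the registry and every name in it are invented for illustration and merely stand in for MultiPlatformManager.

// Minimal self-registration sketch; the registry here is invented and is not
// MultiPlatformManager or REGISTER_MODULE_INITIALIZER.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Platform {
  std::string name;
};

// A process-wide registry that owns registered platforms.
std::vector<std::unique_ptr<Platform>>& PlatformRegistry() {
  static auto* registry = new std::vector<std::unique_ptr<Platform>>();
  return *registry;
}

// Guarded registration, analogous in spirit to RegisterTpuPlatform():
// safe to call more than once, registers exactly one instance.
void RegisterFakeTpuPlatform() {
  static bool registered = false;
  if (!registered) {
    auto platform = std::make_unique<Platform>();
    platform->name = "TPU";
    PlatformRegistry().push_back(std::move(platform));
    registered = true;
  }
}

namespace {
// Stand-in for a module initializer: run the registration during static
// initialization, before main() executes.
const bool kRegisteredAtStartup = (RegisterFakeTpuPlatform(), true);
}  // namespace

int main() {
  std::cout << "registered at startup: " << kRegisteredAtStartup << "\n";
  std::cout << "registered platforms: " << PlatformRegistry().size() << "\n";
  return 0;
}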
@@ -60,9 +60,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
 
   bool ShouldRegisterTpuDeviceToDeviceCopy() override;
 
-  bool Initialized() const override {
-    return TpuPlatform_Initialized(platform_);
-  }
+  bool Initialized() const override;
 
   Status Initialize(
       const std::map<std::string, std::string>& platform_options) override;

@@ -124,6 +122,8 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
   EventMap event_map_;
 };
 
+void RegisterTpuPlatform();
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
 
+#include "tensorflow/core/tpu/tpu_library_loader.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
 #include "tensorflow/stream_executor/tpu/status_helper.h"
@@ -27,11 +28,14 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
   using Status = stream_executor::port::Status;
 
   explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
-  ~TpuStream() override { TpuStream_Free(stream_); }
+  ~TpuStream() override {
+    tensorflow::tpu::ExecutorApiFn()->TpuStream_FreeFn(stream_);
+  }
 
   bool IsSameSharedMemoryLocation(
       tensorflow::tpu::TpuStreamInterface* other) override {
-    return TpuStream_IsSameSharedMemoryLocation(
+    return tensorflow::tpu::ExecutorApiFn()
+        ->TpuStream_IsSameSharedMemoryLocationFn(
             stream_, static_cast<TpuStream*>(other)->stream_);
   }
 
@@ -39,7 +43,8 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
       stream_executor::DeviceMemoryBase send_buffer,
       stream_executor::DeviceMemoryBase recv_buffer) override {
     StatusHelper status;
-    TpuStream_TpuEnqueueOnDeviceSendRecvLocal(
+    tensorflow::tpu::ExecutorApiFn()
+        ->TpuStream_TpuEnqueueOnDeviceSendRecvLocalFn(
             stream_,
             TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(send_buffer),
             TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(recv_buffer),
@@ -54,7 +59,9 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface {
 class TpuEvent : public ::stream_executor::internal::EventInterface {
  public:
   explicit TpuEvent(SE_Event* event) : event_(event) {}
-  ~TpuEvent() override { TpuEvent_Free(event_); }
+  ~TpuEvent() override {
+    tensorflow::tpu::ExecutorApiFn()->TpuEvent_FreeFn(event_);
+  }
 
  private:
   SE_Event* event_;
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
 
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/tpu/tpu_library_loader.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
 
@@ -25,9 +26,15 @@ namespace tensorflow {
 class TpuTimer : public ::stream_executor::internal::TimerInterface {
  public:
   explicit TpuTimer(SE_Timer* timer) : timer_(timer) {}
-  ~TpuTimer() override { TpuTimer_Free(timer_); }
-  uint64 Microseconds() const override { return TpuTimer_Microseconds(timer_); }
-  uint64 Nanoseconds() const override { return TpuTimer_Nanoseconds(timer_); }
+  ~TpuTimer() override {
+    tensorflow::tpu::ExecutorApiFn()->TpuTimer_FreeFn(timer_);
+  }
+  uint64 Microseconds() const override {
+    return tensorflow::tpu::ExecutorApiFn()->TpuTimer_MicrosecondsFn(timer_);
+  }
+  uint64 Nanoseconds() const override {
+    return tensorflow::tpu::ExecutorApiFn()->TpuTimer_NanosecondsFn(timer_);
+  }
 
  private:
   SE_Timer* timer_;
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/tpu/tpu_library_loader.h"
 #include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
 #include "tensorflow/stream_executor/tpu/proto_helper.h"
@@ -29,10 +30,12 @@ namespace tensorflow {
 using Status = stream_executor::port::Status;
 
 TpuTransferManager::TpuTransferManager() {
-  manager_ = TpuTransferManager_New();
+  manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
 }
 
-TpuTransferManager::~TpuTransferManager() { TpuTransferManager_Free(manager_); }
+TpuTransferManager::~TpuTransferManager() {
+  tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
+}
 
 stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
   return TpuPlatform::kId;
@@ -45,8 +48,8 @@ xla::Shape TpuTransferManager::HostShapeToDeviceShape(
 
   TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
 
-  TpuTransferManager_HostShapeToDeviceShape(manager_, &c_host_shape,
-                                            &c_device_shape);
+  tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
+      manager_, &c_host_shape, &c_device_shape);
   xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
   TpuConversions::CShapeCleanup(&c_host_shape);
   TpuConversions::CShapeCleanup(&c_device_shape);
@@ -66,7 +69,7 @@ Status TpuTransferManager::TransferLiteralToDeviceAsync(
   TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
                                                  &c_device_buffer);
 
-  TpuTransferManager_TransferLiteralToDeviceAsync(
+  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
       manager_,
       TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
           stream->implementation()),
@@ -112,7 +115,7 @@ void TpuTransferManager::TransferLiteralFromDevice(
   XLA_Literal c_literal;
   TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
 
-  TpuTransferManager_TransferLiteralFromDevice(
+  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
       manager_,
       TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
          stream->implementation()),
@@ -127,7 +130,8 @@ int64 TpuTransferManager::GetByteSizeRequirement(
   TpuConversions::XlaShapeToCShape(shape, &c_shape);
 
   int64 size_in_bytes =
-      TpuTransferManager_GetByteSizeRequirement(manager_, &c_shape);
+      tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
+          manager_, &c_shape);
 
   TpuConversions::CShapeCleanup(&c_shape);
   return size_in_bytes;
@@ -151,7 +155,7 @@ Status TpuTransferManager::WriteSingleTupleIndexTable(
                          region->payload()};
   StatusHelper status;
 
-  TpuTransferManager_WriteSingleTupleIndexTable(
+  tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
          stream->implementation()),
@@ -2899,6 +2899,13 @@ def if_mlir(if_true, if_false = []):
         "//conditions:default": if_false,
     })
 
+def if_tpu(if_true, if_false = []):
+    """Shorthand for select()ing whether to build for TPUs."""
+    return select({
+        str(Label("//tensorflow:with_tpu_support")): if_true,
+        "//conditions:default": if_false,
+    })
+
 def tfcompile_target_cpu():
     return ""