Group together execution-related C API methods.
PiperOrigin-RevId: 321666252 Change-Id: Ife996fa7c3fca44546487198990a46fd61fdcf62
This commit is contained in:
parent
7a745bea24
commit
3cc65294f8
@ -125,6 +125,7 @@ cc_library(
|
|||||||
":libtftpu_header",
|
":libtftpu_header",
|
||||||
":tpu_config_c_api",
|
":tpu_config_c_api",
|
||||||
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
|
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
|
||||||
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
|
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
|
||||||
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
|
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
@ -149,6 +150,7 @@ cc_library(
|
|||||||
"//tensorflow/core/platform:status",
|
"//tensorflow/core/platform:status",
|
||||||
"//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
|
"//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
|
||||||
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
|
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
|
||||||
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
|
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
|
||||||
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
|
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
|
@ -523,6 +523,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":tpu_program_c_api_hdrs",
|
":tpu_program_c_api_hdrs",
|
||||||
":tpu_util_c_api_hdrs",
|
":tpu_util_c_api_hdrs",
|
||||||
|
"//tensorflow/core/tpu:libtftpu_header",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -26,13 +27,23 @@ typedef struct XLA_DeviceAssignment {
|
|||||||
size_t size;
|
size_t size;
|
||||||
} XLA_DeviceAssignment;
|
} XLA_DeviceAssignment;
|
||||||
|
|
||||||
void TpuExecutable_LoadProgramAndEnqueueToStream(
|
TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
|
||||||
const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments,
|
const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments,
|
||||||
size_t arguments_len, SE_DeviceMemoryBase* result,
|
size_t arguments_len, SE_DeviceMemoryBase* result,
|
||||||
SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed,
|
SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed,
|
||||||
XLA_DeviceAssignment* device_assignment, SE_Stream* stream,
|
XLA_DeviceAssignment* device_assignment, SE_Stream* stream,
|
||||||
SE_Status* status);
|
SE_Status* status);
|
||||||
|
|
||||||
|
TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
|
||||||
|
XLA_Shape* host_shape, XLA_Shape* device_shape);
|
||||||
|
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
|
||||||
|
|
||||||
|
struct TfTpu_ExecuteApiFn {
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
|
||||||
|
};
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_
|
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_
|
||||||
|
@ -38,6 +38,11 @@ TfTpu_CompileApiFn* CompileApiFn() {
|
|||||||
return &compile_api_fn;
|
return &compile_api_fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfTpu_ExecuteApiFn* ExecuteApiFn() {
|
||||||
|
static TfTpu_ExecuteApiFn execute_api_fn;
|
||||||
|
return &execute_api_fn;
|
||||||
|
}
|
||||||
|
|
||||||
TfTpu_TpuProgramApiFn* TpuProgramApiFn() {
|
TfTpu_TpuProgramApiFn* TpuProgramApiFn() {
|
||||||
static TfTpu_TpuProgramApiFn tpu_program_api_fn;
|
static TfTpu_TpuProgramApiFn tpu_program_api_fn;
|
||||||
return &tpu_program_api_fn;
|
return &tpu_program_api_fn;
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_TPU_TPU_API_H_
|
#define TENSORFLOW_CORE_TPU_TPU_API_H_
|
||||||
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
||||||
#include "tensorflow/core/tpu/libtftpu.h"
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
@ -35,6 +36,8 @@ TfTpu_MeshStateApiFn* MeshStateApiFn();
|
|||||||
|
|
||||||
TfTpu_CompileApiFn* CompileApiFn();
|
TfTpu_CompileApiFn* CompileApiFn();
|
||||||
|
|
||||||
|
TfTpu_ExecuteApiFn* ExecuteApiFn();
|
||||||
|
|
||||||
TfTpu_TpuProgramApiFn* TpuProgramApiFn();
|
TfTpu_TpuProgramApiFn* TpuProgramApiFn();
|
||||||
|
|
||||||
TfTpu_ExecutorApiFn* ExecutorApiFn();
|
TfTpu_ExecutorApiFn* ExecutorApiFn();
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
||||||
#include "tensorflow/core/tpu/libtftpu.h"
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
|
@ -37,6 +37,16 @@ tensorflow::Status SetCompileStructFn(void* library_handle) {
|
|||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::Status SetExecuteStructFn(void* library_handle) {
|
||||||
|
auto* execute_fn = tensorflow::tpu::ExecuteApiFn();
|
||||||
|
|
||||||
|
TFTPU_SET_FN(execute_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
|
||||||
|
TFTPU_SET_FN(execute_fn, HardwareLayout_HostShapeToDeviceShape);
|
||||||
|
TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSize);
|
||||||
|
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
|
tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
|
||||||
auto* tpu_program_fn = tensorflow::tpu::TpuProgramApiFn();
|
auto* tpu_program_fn = tensorflow::tpu::TpuProgramApiFn();
|
||||||
|
|
||||||
@ -145,9 +155,6 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
|||||||
TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
|
||||||
TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
|
TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
|
||||||
|
|
||||||
TFTPU_SET_FN(executor_fn, HardwareLayout_HostShapeToDeviceShape);
|
|
||||||
TFTPU_SET_FN(executor_fn, HardwareLayout_ShapeSize);
|
|
||||||
|
|
||||||
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_New);
|
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_New);
|
||||||
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_Free);
|
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_Free);
|
||||||
|
|
||||||
@ -197,6 +204,7 @@ tensorflow::Status InitializeTpuStructFns(void* library_handle) {
|
|||||||
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
|
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
|
||||||
TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle));
|
TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle));
|
||||||
TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle));
|
TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle));
|
||||||
|
TF_RETURN_IF_ERROR(SetExecuteStructFn(library_handle));
|
||||||
TF_RETURN_IF_ERROR(SetTpuProgramStructFn(library_handle));
|
TF_RETURN_IF_ERROR(SetTpuProgramStructFn(library_handle));
|
||||||
TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
|
TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
|
||||||
TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle));
|
TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle));
|
||||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
|
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
|
||||||
|
|
||||||
@ -80,10 +79,11 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
|
|||||||
run_options.run_options().stream()->implementation());
|
run_options.run_options().stream()->implementation());
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
|
|
||||||
TpuExecutable_LoadProgramAndEnqueueToStream(
|
tensorflow::tpu::ExecuteApiFn()
|
||||||
core_program_, arguments_bases, arguments.size(), &result_base,
|
->TpuExecutable_LoadProgramAndEnqueueToStreamFn(
|
||||||
(cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr),
|
core_program_, arguments_bases, arguments.size(), &result_base,
|
||||||
rng_seed, &c_dev_assign, stream, status.c_status);
|
(cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr),
|
||||||
|
rng_seed, &c_dev_assign, stream, status.c_status);
|
||||||
|
|
||||||
if (dev_assign != nullptr) {
|
if (dev_assign != nullptr) {
|
||||||
stream_executor::tpu::SerializedProto_Free(dev_assign_serialized);
|
stream_executor::tpu::SerializedProto_Free(dev_assign_serialized);
|
||||||
@ -96,7 +96,7 @@ Shape TpuExecutable::HostShapeToDeviceShape(const Shape& host_shape) {
|
|||||||
XLA_Shape c_host_shape;
|
XLA_Shape c_host_shape;
|
||||||
XLA_Shape c_device_shape;
|
XLA_Shape c_device_shape;
|
||||||
TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
|
TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
|
||||||
tensorflow::tpu::ExecutorApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
|
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
|
||||||
&c_host_shape, &c_device_shape);
|
&c_host_shape, &c_device_shape);
|
||||||
Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
|
Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
|
||||||
TpuConversions::CShapeCleanup(&c_host_shape);
|
TpuConversions::CShapeCleanup(&c_host_shape);
|
||||||
@ -108,7 +108,7 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) {
|
|||||||
XLA_Shape c_shape;
|
XLA_Shape c_shape;
|
||||||
TpuConversions::XlaShapeToCShape(shape, &c_shape);
|
TpuConversions::XlaShapeToCShape(shape, &c_shape);
|
||||||
int64 size =
|
int64 size =
|
||||||
tensorflow::tpu::ExecutorApiFn()->HardwareLayout_ShapeSizeFn(&c_shape);
|
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeFn(&c_shape);
|
||||||
TpuConversions::CShapeCleanup(&c_shape);
|
TpuConversions::CShapeCleanup(&c_shape);
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
@ -290,10 +290,6 @@ void TpuTransferManager_WriteSingleTupleIndexTable(
|
|||||||
SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
|
SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
|
||||||
SE_DeviceMemoryBase* region, SE_Status* status);
|
SE_DeviceMemoryBase* region, SE_Status* status);
|
||||||
|
|
||||||
void HardwareLayout_HostShapeToDeviceShape(XLA_Shape* host_shape,
|
|
||||||
XLA_Shape* device_shape);
|
|
||||||
int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
|
|
||||||
|
|
||||||
XLA_ComputationPlacer* TpuComputationPlacer_New();
|
XLA_ComputationPlacer* TpuComputationPlacer_New();
|
||||||
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
|
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
|
||||||
|
|
||||||
@ -401,9 +397,6 @@ struct TfTpu_ExecutorApiFn {
|
|||||||
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
|
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
|
||||||
|
|
||||||
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
|
|
||||||
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
|
|
||||||
|
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_New);
|
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_New);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_Free);
|
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_Free);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user