From 3cc65294f89819163fc1e4d1e96d6807cfc33afc Mon Sep 17 00:00:00 2001 From: Wenhao Jia Date: Thu, 16 Jul 2020 16:07:28 -0700 Subject: [PATCH] Group together execution-related C API methods. PiperOrigin-RevId: 321666252 Change-Id: Ife996fa7c3fca44546487198990a46fd61fdcf62 --- tensorflow/core/tpu/BUILD | 2 ++ tensorflow/core/tpu/kernels/BUILD | 1 + tensorflow/core/tpu/kernels/tpu_execute_c_api.h | 13 ++++++++++++- tensorflow/core/tpu/tpu_api.cc | 5 +++++ tensorflow/core/tpu/tpu_api.h | 3 +++ tensorflow/core/tpu/tpu_api_dlsym_initializer.h | 1 + tensorflow/core/tpu/tpu_library_init_fns.inc | 14 +++++++++++--- tensorflow/stream_executor/tpu/tpu_executable.cc | 14 +++++++------- .../stream_executor/tpu/tpu_executor_c_api.h | 7 ------- 9 files changed, 42 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 7639cacc378..f9031b440f9 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -125,6 +125,7 @@ cc_library( ":libtftpu_header", ":tpu_config_c_api", "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs", + "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs", "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs", "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", @@ -149,6 +150,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration", "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs", + "//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs", "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs", "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 89a36ed9ae4..7a6160a2963 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -523,6 +523,7 @@ cc_library( deps = [ ":tpu_program_c_api_hdrs", ":tpu_util_c_api_hdrs", + "//tensorflow/core/tpu:libtftpu_header", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", ], ) diff --git a/tensorflow/core/tpu/kernels/tpu_execute_c_api.h b/tensorflow/core/tpu/kernels/tpu_execute_c_api.h index db73af76efd..38a550444a9 100644 --- a/tensorflow/core/tpu/kernels/tpu_execute_c_api.h +++ b/tensorflow/core/tpu/kernels/tpu_execute_c_api.h @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_program_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" +#include "tensorflow/core/tpu/libtftpu.h" #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" extern "C" { @@ -26,13 +27,23 @@ typedef struct XLA_DeviceAssignment { size_t size; } XLA_DeviceAssignment; -void TpuExecutable_LoadProgramAndEnqueueToStream( +TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream( const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments, size_t arguments_len, SE_DeviceMemoryBase* result, SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed, XLA_DeviceAssignment* device_assignment, SE_Stream* stream, SE_Status* status); +TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape( + XLA_Shape* host_shape, XLA_Shape* device_shape); +TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape); + +struct TfTpu_ExecuteApiFn { + TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream); + TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape); + TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize); +}; + } // extern "C" #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_ diff --git a/tensorflow/core/tpu/tpu_api.cc b/tensorflow/core/tpu/tpu_api.cc index 3ce7626de2b..cd6ca80e4e7 100644 --- a/tensorflow/core/tpu/tpu_api.cc +++ b/tensorflow/core/tpu/tpu_api.cc @@ -38,6 +38,11 @@ TfTpu_CompileApiFn* CompileApiFn() { return &compile_api_fn; } +TfTpu_ExecuteApiFn* ExecuteApiFn() { + static TfTpu_ExecuteApiFn execute_api_fn; + return &execute_api_fn; +} + TfTpu_TpuProgramApiFn* TpuProgramApiFn() { static TfTpu_TpuProgramApiFn tpu_program_api_fn; return &tpu_program_api_fn; diff --git a/tensorflow/core/tpu/tpu_api.h b/tensorflow/core/tpu/tpu_api.h index 3467f82a180..b6edbfd14bb 100644 --- a/tensorflow/core/tpu/tpu_api.h +++ b/tensorflow/core/tpu/tpu_api.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_TPU_TPU_API_H_ #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" #include "tensorflow/core/tpu/libtftpu.h" @@ -35,6 +36,8 @@ TfTpu_MeshStateApiFn* MeshStateApiFn(); TfTpu_CompileApiFn* CompileApiFn(); +TfTpu_ExecuteApiFn* ExecuteApiFn(); + TfTpu_TpuProgramApiFn* TpuProgramApiFn(); TfTpu_ExecutorApiFn* ExecutorApiFn(); diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.h b/tensorflow/core/tpu/tpu_api_dlsym_initializer.h index 257fa25ad37..1126e132264 100644 --- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.h +++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" #include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" #include "tensorflow/core/tpu/libtftpu.h" diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc index 6737ae42570..7a7c6ecad30 100644 --- a/tensorflow/core/tpu/tpu_library_init_fns.inc +++ b/tensorflow/core/tpu/tpu_library_init_fns.inc @@ -37,6 +37,16 @@ tensorflow::Status SetCompileStructFn(void* library_handle) { return tensorflow::Status::OK(); } +tensorflow::Status SetExecuteStructFn(void* library_handle) { + auto* execute_fn = tensorflow::tpu::ExecuteApiFn(); + + TFTPU_SET_FN(execute_fn, TpuExecutable_LoadProgramAndEnqueueToStream); + TFTPU_SET_FN(execute_fn, HardwareLayout_HostShapeToDeviceShape); + TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSize); + + return tensorflow::Status::OK(); +} + tensorflow::Status SetTpuProgramStructFn(void* library_handle) { auto* tpu_program_fn = tensorflow::tpu::TpuProgramApiFn(); @@ -145,9 +155,6 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) { TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement); TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable); - TFTPU_SET_FN(executor_fn, HardwareLayout_HostShapeToDeviceShape); - TFTPU_SET_FN(executor_fn, HardwareLayout_ShapeSize); - TFTPU_SET_FN(executor_fn, TpuComputationPlacer_New); TFTPU_SET_FN(executor_fn, TpuComputationPlacer_Free); @@ -197,6 +204,7 @@ tensorflow::Status InitializeTpuStructFns(void* library_handle) { TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle)); TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle)); TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle)); + TF_RETURN_IF_ERROR(SetExecuteStructFn(library_handle)); TF_RETURN_IF_ERROR(SetTpuProgramStructFn(library_handle)); TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle)); TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle)); diff --git a/tensorflow/stream_executor/tpu/tpu_executable.cc b/tensorflow/stream_executor/tpu/tpu_executable.cc index f6ded8415c1..e8ff3a54db8 100644 --- a/tensorflow/stream_executor/tpu/tpu_executable.cc +++ b/tensorflow/stream_executor/tpu/tpu_executable.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/stream_executor/tpu/c_api_conversions.h" #include "tensorflow/stream_executor/tpu/proto_helper.h" #include "tensorflow/stream_executor/tpu/status_helper.h" -#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_platform.h" #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" @@ -80,10 +79,11 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream( run_options.run_options().stream()->implementation()); StatusHelper status; - TpuExecutable_LoadProgramAndEnqueueToStream( - core_program_, arguments_bases, arguments.size(), &result_base, - (cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr), - rng_seed, &c_dev_assign, stream, status.c_status); + tensorflow::tpu::ExecuteApiFn() + ->TpuExecutable_LoadProgramAndEnqueueToStreamFn( + core_program_, arguments_bases, arguments.size(), &result_base, + (cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr), + rng_seed, &c_dev_assign, stream, status.c_status); if (dev_assign != nullptr) { stream_executor::tpu::SerializedProto_Free(dev_assign_serialized); @@ -96,7 +96,7 @@ Shape TpuExecutable::HostShapeToDeviceShape(const Shape& host_shape) { XLA_Shape c_host_shape; XLA_Shape c_device_shape; TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape); - tensorflow::tpu::ExecutorApiFn()->HardwareLayout_HostShapeToDeviceShapeFn( + tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn( &c_host_shape, &c_device_shape); Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape); TpuConversions::CShapeCleanup(&c_host_shape); @@ -108,7 +108,7 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) { XLA_Shape c_shape; TpuConversions::XlaShapeToCShape(shape, &c_shape); int64 size = - tensorflow::tpu::ExecutorApiFn()->HardwareLayout_ShapeSizeFn(&c_shape); + tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeFn(&c_shape); TpuConversions::CShapeCleanup(&c_shape); return size; } diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h index 1530c00e621..e99151c5dc3 100644 --- a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h +++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h @@ -290,10 +290,6 @@ void TpuTransferManager_WriteSingleTupleIndexTable( SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape, SE_DeviceMemoryBase* region, SE_Status* status); -void HardwareLayout_HostShapeToDeviceShape(XLA_Shape* host_shape, - XLA_Shape* device_shape); -int64_t HardwareLayout_ShapeSize(XLA_Shape* shape); - XLA_ComputationPlacer* TpuComputationPlacer_New(); void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer); @@ -401,9 +397,6 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement); TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable); - TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape); - TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize); - TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_New); TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_Free);