Group together execution-related C API methods.

PiperOrigin-RevId: 321666252
Change-Id: Ife996fa7c3fca44546487198990a46fd61fdcf62
This commit is contained in:
Wenhao Jia 2020-07-16 16:07:28 -07:00 committed by TensorFlower Gardener
parent 7a745bea24
commit 3cc65294f8
9 changed files with 42 additions and 18 deletions

View File

@ -125,6 +125,7 @@ cc_library(
":libtftpu_header",
":tpu_config_c_api",
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
@ -149,6 +150,7 @@ cc_library(
"//tensorflow/core/platform:status",
"//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_execute_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",

View File

@ -523,6 +523,7 @@ cc_library(
deps = [
":tpu_program_c_api_hdrs",
":tpu_util_c_api_hdrs",
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
],
)

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
extern "C" {
@ -26,13 +27,23 @@ typedef struct XLA_DeviceAssignment {
size_t size;
} XLA_DeviceAssignment;
void TpuExecutable_LoadProgramAndEnqueueToStream(
// Exported C API entry point; judging by its name it loads `program` and
// enqueues it on `stream` (the implementation is not visible here).
//
// `arguments` / `arguments_len` form a pointer-plus-count pair.
// `cross_program_prefetch_addr` may be nullptr: the C++ caller
// TpuExecutable::LoadProgramAndEnqueueToStream passes nullptr when it has no
// prefetch address.
// NOTE(review): `result` and `status` look like out-parameters written by the
// implementation -- confirm against the library side.
TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
const XLA_TpuProgram* program, SE_DeviceMemoryBase* arguments,
size_t arguments_len, SE_DeviceMemoryBase* result,
SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed,
XLA_DeviceAssignment* device_assignment, SE_Stream* stream,
SE_Status* status);
// Exported C API entry point taking a host shape and a device shape.
// The C++ caller TpuExecutable::HostShapeToDeviceShape fills `host_shape`
// before the call and reads `device_shape` afterwards, i.e. `device_shape` is
// used as the output.
TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
XLA_Shape* host_shape, XLA_Shape* device_shape);
// Exported C API entry point returning a size for `shape`.
// NOTE(review): the unit is not visible here -- presumably bytes; confirm.
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
// Table of the execution-related C API functions, one entry per
// TFTPU_ADD_FN_IN_STRUCT line. Callers reach an entry through a member named
// after the function with an `Fn` suffix (e.g.
// `HardwareLayout_ShapeSizeFn`); the entries are filled in by
// SetExecuteStructFn.
// NOTE(review): TFTPU_ADD_FN_IN_STRUCT is defined elsewhere; presumably it
// declares a function-pointer member -- confirm against the macro.
struct TfTpu_ExecuteApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
};
} // extern "C"
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_C_API_H_

View File

@ -38,6 +38,11 @@ TfTpu_CompileApiFn* CompileApiFn() {
return &compile_api_fn;
}
// Returns the address of the single TfTpu_ExecuteApiFn instance. The instance
// is a function-local static, so every call yields the same pointer.
TfTpu_ExecuteApiFn* ExecuteApiFn() {
  static TfTpu_ExecuteApiFn api_table;
  TfTpu_ExecuteApiFn* const table_ptr = &api_table;
  return table_ptr;
}
TfTpu_TpuProgramApiFn* TpuProgramApiFn() {
static TfTpu_TpuProgramApiFn tpu_program_api_fn;
return &tpu_program_api_fn;

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_TPU_TPU_API_H_
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
@ -35,6 +36,8 @@ TfTpu_MeshStateApiFn* MeshStateApiFn();
TfTpu_CompileApiFn* CompileApiFn();
// Returns a pointer to the table of execution-related C API functions.
TfTpu_ExecuteApiFn* ExecuteApiFn();
TfTpu_TpuProgramApiFn* TpuProgramApiFn();
TfTpu_ExecutorApiFn* ExecutorApiFn();

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"

View File

@ -37,6 +37,16 @@ tensorflow::Status SetCompileStructFn(void* library_handle) {
return tensorflow::Status::OK();
}
// Fills in the TfTpu_ExecuteApiFn table obtained from
// tensorflow::tpu::ExecuteApiFn() with the three execution-related entry
// points, one TFTPU_SET_FN invocation per function.
// NOTE(review): TFTPU_SET_FN is defined outside this view. `library_handle`
// is not referenced directly in this body, so presumably the macro uses it to
// resolve each symbol -- confirm against the macro definition.
// Returns tensorflow::Status::OK() when the end of the body is reached.
tensorflow::Status SetExecuteStructFn(void* library_handle) {
auto* execute_fn = tensorflow::tpu::ExecuteApiFn();
TFTPU_SET_FN(execute_fn, TpuExecutable_LoadProgramAndEnqueueToStream);
TFTPU_SET_FN(execute_fn, HardwareLayout_HostShapeToDeviceShape);
TFTPU_SET_FN(execute_fn, HardwareLayout_ShapeSize);
return tensorflow::Status::OK();
}
tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
auto* tpu_program_fn = tensorflow::tpu::TpuProgramApiFn();
@ -145,9 +155,6 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
TFTPU_SET_FN(executor_fn, HardwareLayout_HostShapeToDeviceShape);
TFTPU_SET_FN(executor_fn, HardwareLayout_ShapeSize);
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_New);
TFTPU_SET_FN(executor_fn, TpuComputationPlacer_Free);
@ -197,6 +204,7 @@ tensorflow::Status InitializeTpuStructFns(void* library_handle) {
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
TF_RETURN_IF_ERROR(SetTpuMeshStateStructFns(library_handle));
TF_RETURN_IF_ERROR(SetCompileStructFn(library_handle));
TF_RETURN_IF_ERROR(SetExecuteStructFn(library_handle));
TF_RETURN_IF_ERROR(SetTpuProgramStructFn(library_handle));
TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle));
TF_RETURN_IF_ERROR(SetTpuNodeContextStructFns(library_handle));

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
@ -80,10 +79,11 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
run_options.run_options().stream()->implementation());
StatusHelper status;
TpuExecutable_LoadProgramAndEnqueueToStream(
core_program_, arguments_bases, arguments.size(), &result_base,
(cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr),
rng_seed, &c_dev_assign, stream, status.c_status);
tensorflow::tpu::ExecuteApiFn()
->TpuExecutable_LoadProgramAndEnqueueToStreamFn(
core_program_, arguments_bases, arguments.size(), &result_base,
(cross_program_prefetch_addr.has_value() ? &prefetch_base : nullptr),
rng_seed, &c_dev_assign, stream, status.c_status);
if (dev_assign != nullptr) {
stream_executor::tpu::SerializedProto_Free(dev_assign_serialized);
@ -96,7 +96,7 @@ Shape TpuExecutable::HostShapeToDeviceShape(const Shape& host_shape) {
XLA_Shape c_host_shape;
XLA_Shape c_device_shape;
TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
tensorflow::tpu::ExecutorApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
&c_host_shape, &c_device_shape);
Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
TpuConversions::CShapeCleanup(&c_host_shape);
@ -108,7 +108,7 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) {
XLA_Shape c_shape;
TpuConversions::XlaShapeToCShape(shape, &c_shape);
int64 size =
tensorflow::tpu::ExecutorApiFn()->HardwareLayout_ShapeSizeFn(&c_shape);
tensorflow::tpu::ExecuteApiFn()->HardwareLayout_ShapeSizeFn(&c_shape);
TpuConversions::CShapeCleanup(&c_shape);
return size;
}

View File

@ -290,10 +290,6 @@ void TpuTransferManager_WriteSingleTupleIndexTable(
SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
SE_DeviceMemoryBase* region, SE_Status* status);
void HardwareLayout_HostShapeToDeviceShape(XLA_Shape* host_shape,
XLA_Shape* device_shape);
int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
XLA_ComputationPlacer* TpuComputationPlacer_New();
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
@ -401,9 +397,6 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_New);
TFTPU_ADD_FN_IN_STRUCT(TpuComputationPlacer_Free);