address review comments

Zhoulong Jiang 2020-10-09 20:36:59 +00:00
parent f599aed5d8
commit b8a26f2d74
16 changed files with 149 additions and 76 deletions

View File

@@ -492,8 +492,6 @@ tf_cuda_library(
    ],
    hdrs = [
        "kernels.h",
        "//tensorflow/c/experimental/stream_executor:stream_executor.h",
        "//tensorflow/c/experimental/stream_executor:stream_executor_internal.h",
    ],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
@@ -501,7 +499,7 @@ tf_cuda_library(
        ":tf_status",
        ":tf_status_helper",
        ":tf_tensor_internal",
        "//tensorflow/c/experimental/stream_executor:stream_executor",
    ] + select({
        "//tensorflow:android": [
            ":c_api_internal",
@@ -581,7 +579,6 @@ tf_cuda_cc_test(
    srcs = ["c_api_test.cc"],
    data = [
        ":test_op1.so",
        ":test_pluggable_device.so",
        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
    ],
    linkopts = select({
@@ -633,7 +630,10 @@ tf_cc_test(
    name = "c_api_experimental_test",
    size = "medium",
    srcs = ["c_api_experimental_test.cc"],
    data = ["testdata/tf_record"],
    data = [
        "testdata/tf_record",
        "//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
    ],
    linkopts = select({
        "//tensorflow:macos": ["-headerpad_max_install_names"],
        "//conditions:default": [],
@@ -653,6 +653,7 @@ tf_cc_test(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/platform:resource_loader",
        "@com_google_absl//absl/types:optional",
    ],
)
@@ -692,19 +693,6 @@ tf_custom_op_library(
    srcs = ["test_op1.cc"],
)

cc_library(
    name = "test_pluggable_device_impl",
    srcs = ["test_pluggable_device.cc"],
    hdrs = [
        "c_api_macros.h",
        "tf_status.h",
        "//tensorflow/c/experimental/stream_executor:stream_executor.h",
    ],
)

tf_custom_op_library(
    name = "test_pluggable_device.so",
    deps = [":test_pluggable_device_impl"],
)

tf_kernel_library(
    name = "test_op_kernel",
    srcs = ["test_op.cc"],

View File

@@ -308,9 +308,6 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
Status LoadDynamicLibrary(const char* library_filename, void** result,
                          const void** buf, size_t* len);

// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);

// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
@@ -584,23 +581,6 @@ TF_Buffer* TF_GetAllOpList() {
  return ret;
}

TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
  TF_Library* lib_handle = new TF_Library;
  status->status = tensorflow::LoadPluggableDeviceLibrary(
      library_filename, &lib_handle->lib_handle);
  if (!status->status.ok()) {
    delete lib_handle;
    return nullptr;
  }
  return lib_handle;
}

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle == nullptr) return;
  delete lib_handle;
}

// --------------------------------------------------------------------------
// ListDevices & SessionListDevices API

View File

@@ -630,6 +630,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
@@ -743,3 +746,26 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}

TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  status->status = tensorflow::LoadPluggableDeviceLibrary(
      library_filename, &lib_handle->lib_handle);
  if (!status->status.ok()) {
    delete lib_handle;
    return nullptr;
  }
  return lib_handle;
#endif
}

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle == nullptr) return;
  delete lib_handle;
}

View File

@@ -304,6 +304,25 @@ TF_CAPI_EXPORT extern void
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable);

// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, place OK in status and return the newly created library handle.
// The caller owns the library handle.
//
// On failure, place an error status in status and return NULL.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
    const char* library_filename, TF_Status* status);

// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
    TF_Library* lib_handle);

#ifdef __cplusplus
} /* end extern "C" */
#endif
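For reference while reviewing, a minimal caller-side sketch of the two entry points declared above; the plugin filename and the error handling are illustrative assumptions, not part of this commit:

#include <stdio.h>

#include "tensorflow/c/c_api_experimental.h"

int main(void) {
  TF_Status* status = TF_NewStatus();
  // Placeholder path; the actual lookup rules are platform-specific, as the
  // comment above notes.
  TF_Library* lib =
      TF_LoadPluggableDeviceLibrary("test_pluggable_device.so", status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "Load failed: %s\n", TF_Message(status));
  }
  // Frees only the handle; the shared object itself is NOT unloaded.
  TF_DeletePluggableDeviceLibraryHandle(lib);  // Safe even if lib is NULL.
  TF_DeleteStatus(status);
  return 0;
}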

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

@@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
  TF_DeleteTensor(tensor_1X6);
}

TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  // Load the library.
  TF_Status* status = TF_NewStatus();
  string lib_path =
      tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
          "tensorflow", "c", "experimental", "stream_executor", "test",
          "test_pluggable_device.so"));
  TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
  TF_Code code = TF_GetCode(status);
  string status_msg(TF_Message(status));
  TF_DeleteStatus(status);
  ASSERT_EQ(TF_OK, code) << status_msg;
  TF_DeletePluggableDeviceLibraryHandle(lib);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}

}  // namespace
}  // namespace tensorflow

View File

@@ -228,21 +228,6 @@ TEST(CAPI, LibraryLoadFunctions) {
  }
}

TEST(CAPI, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  // Load the library.
  TF_Status* status = TF_NewStatus();
  string lib_path = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "c", "test_pluggable_device.so"));
  TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
  TF_Code code = TF_GetCode(status);
  string status_msg(TF_Message(status));
  TF_DeleteStatus(status);
  ASSERT_EQ(TF_OK, code) << status_msg;
  TF_DeletePluggableDeviceLibraryHandle(lib);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}

void TestEncodeDecode(int line, const std::vector<string>& data) {
  const tensorflow::int64 n = data.size();
  Status status;

View File

@@ -37,6 +37,7 @@ cc_library(
        "stream_executor.h",
        "stream_executor_internal.h",
    ],
    visibility = ["//tensorflow/c:__subpackages__"],
    deps = [
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c:tf_status",
@@ -67,4 +68,3 @@ exports_files(
    ],
    visibility = ["//visibility:public"],
)

View File

@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"

using tensorflow::StatusFromTF_Status;

namespace stream_executor {

// Plugin initialization function that a device plugin
@@ -103,7 +102,7 @@ class CStream : public internal::StreamInterface {
  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
    return s;
  }

@@ -133,14 +132,14 @@ class CEvent : public internal::EventInterface {
  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_event(device_, &event_handle_, c_status.get());
    return StatusFromTF_Status(c_status.get());
    return tensorflow::StatusFromTF_Status(c_status.get());
  }

  port::Status Record(SP_Stream stream_handle) {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->record_event(device_, stream_handle, event_handle_,
                                   c_status.get());
    return StatusFromTF_Status(c_status.get());
    return tensorflow::StatusFromTF_Status(c_status.get());
  }

  void Destroy() {
@@ -171,7 +170,7 @@ class CTimer : public internal::TimerInterface {
  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
    return StatusFromTF_Status(c_status.get());
    return tensorflow::StatusFromTF_Status(c_status.get());
  }

  void Destroy() {

View File

@@ -0,0 +1,21 @@
# Description:
#   test for stream_executor
load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_library",
)

cc_library(
    name = "test_pluggable_device_impl",
    srcs = ["test_pluggable_device.cc"],
    hdrs = [
        "//tensorflow/c:headers",
        "//tensorflow/c/experimental/stream_executor:stream_executor.h",
    ],
)

tf_custom_op_library(
    name = "test_pluggable_device.so",
    visibility = ["//tensorflow/c:__subpackages__"],
    deps = [":test_pluggable_device_impl"],
)

View File

@ -0,0 +1,24 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
                   TF_Status* const status) {
  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  params->platform->name = "GPU";
  params->platform->type = "XGPU";
}

View File

@@ -170,13 +170,20 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
  TF_SetStatus(status, TF_OK, "");
}

// This function is only for pluggable device
// it will return nullptr in all other cases
// This function is only for pluggable device.
// It will return nullptr in all other cases.
// This function is experimental and subject to change.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  auto c_stream = static_cast<stream_executor::CStream*>(
      cc_ctx->op_device_context()->stream()->implementation());
  return c_stream->Handle();
  if (cc_ctx->op_device_context() == nullptr) {  // CPU Device
    return nullptr;
  } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
    return nullptr;
  } else {  // Is a PluggableDevice
    auto c_stream = static_cast<stream_executor::CStream*>(
        cc_ctx->op_device_context()->stream()->implementation());
    return c_stream->Handle();
  }
}
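To make the nullptr contract above concrete, a sketch of a compute function as a plugin kernel might write it; the fallback branch is an illustrative assumption, and only the nullptr behavior is established by this change:

void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
  SP_Stream stream = TF_GetStream(ctx);
  if (stream == nullptr) {
    // CPU device, or a device context that is not a PluggableDevice;
    // a real kernel would fall back to a host-side path here.
    return;
  }
  // On a PluggableDevice, `stream` wraps the SP_Stream handle the plugin
  // created through its create_stream callback (see CStream above).
}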

int TF_NumInputs(TF_OpKernelContext* ctx) {

View File

@@ -27,7 +27,6 @@ limitations under the License.
#include <utility>

#include "absl/container/inlined_vector.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@@ -52,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

struct MyCustomKernel {
  bool created;
@@ -378,6 +378,20 @@ template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
                     TF_OpKernelContext* ctx);

REGISTER_OP("StreamOp").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestStream) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    SP_Stream stream = TF_GetStream(ctx);
    // Stream is always null if device is not a pluggable device. More test
    // cases will be added when pluggable device mechanism is supported.
    EXPECT_EQ(stream, nullptr);
  };
  SetupOp("StreamOp", "StreamOp", my_compute_func);
  TF_ASSERT_OK(RunOpKernel());
}

REGISTER_OP("AllocateOutputOp1").Output("output1: float");

TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {

View File

@@ -114,6 +114,9 @@ class DeviceContext : public core::RefCounted {
                           std::function<void()> func) {
    return errors::Internal("ThenExecute not supported by device");
  }

  // Check if the device is a pluggable device.
  virtual bool IsPluggableDevice() { return false; }
};
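A hypothetical plugin-side override of the new hook, to show how a device context would opt in; the class name and surrounding plugin code are assumptions for illustration, not part of this commit:

// Sketch only: a PluggableDevice plugin would override the new hook so that
// TF_GetStream() returns its SP_Stream instead of nullptr.
class PluggableDeviceContext : public tensorflow::DeviceContext {
 public:
  bool IsPluggableDevice() override { return true; }
};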

class DeviceBase {

@@ -173,7 +176,6 @@ class DeviceBase {
  // Does not take ownership.
  void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
  // Return the Allocator implementation to use based on the allocator
  // attributes requested. See allocator.h for more details.
  virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
@@ -204,7 +206,6 @@ class DeviceBase {
  virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
  // Caller owns the return value. The OpKernelContext calls this even
  // for devices that do not implement an eigen_gpu_device. Overridden
  // by GPU devices to return a derived type.

View File

@@ -158,7 +158,6 @@ def load_library(library_location):
        'The file or folder to load kernel libraries from does not exist.',
        library_location)

@tf_export('load_pluggable_device_library')
def load_pluggable_device_library(library_location):
  """Loads a Tensorflow PluggableDevice plugin.

  "library_location" can be a path to a specific shared object, or a folder.

View File

@@ -1576,10 +1576,6 @@ tf_module {
    name: "load_op_library"
    argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "load_pluggable_device_library"
    argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "local_variables"
    argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -772,10 +772,6 @@ tf_module {
    name: "load_op_library"
    argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "load_pluggable_device_library"
    argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "logical_and"
    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "