address review comments
This commit is contained in:
parent
f599aed5d8
commit
b8a26f2d74
@ -492,8 +492,6 @@ tf_cuda_library(
|
||||
],
|
||||
hdrs = [
|
||||
"kernels.h",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor.h",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_internal.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
@ -501,7 +499,7 @@ tf_cuda_library(
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
@ -581,7 +579,6 @@ tf_cuda_cc_test(
|
||||
srcs = ["c_api_test.cc"],
|
||||
data = [
|
||||
":test_op1.so",
|
||||
":test_pluggable_device.so",
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
linkopts = select({
|
||||
@ -633,7 +630,10 @@ tf_cc_test(
|
||||
name = "c_api_experimental_test",
|
||||
size = "medium",
|
||||
srcs = ["c_api_experimental_test.cc"],
|
||||
data = ["testdata/tf_record"],
|
||||
data = [
|
||||
"testdata/tf_record",
|
||||
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
|
||||
],
|
||||
linkopts = select({
|
||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
@ -653,6 +653,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
@ -692,19 +693,6 @@ tf_custom_op_library(
|
||||
srcs = ["test_op1.cc"],
|
||||
)
|
||||
|
||||
# Library target that compiles the test PluggableDevice plugin source together
# with the headers listed below.
cc_library(
    name = "test_pluggable_device_impl",
    srcs = ["test_pluggable_device.cc"],
    hdrs = [
        "c_api_macros.h",
        "tf_status.h",
        "//tensorflow/c/experimental/stream_executor:stream_executor.h",
    ],
)
|
||||
|
||||
# Custom-op style library target named "test_pluggable_device.so"; its only
# dependency is the plugin implementation library.
tf_custom_op_library(
|
||||
name = "test_pluggable_device.so",
|
||||
deps = [":test_pluggable_device_impl"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "test_op_kernel",
|
||||
srcs = ["test_op.cc"],
|
||||
|
@ -308,9 +308,6 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
|
||||
Status LoadDynamicLibrary(const char* library_filename, void** result,
|
||||
const void** buf, size_t* len);
|
||||
|
||||
// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
|
||||
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
|
||||
|
||||
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
|
||||
// directly, instead of requiring us to serialize to a GraphDef and
|
||||
// call Session::Extend().
|
||||
@ -584,23 +581,6 @@ TF_Buffer* TF_GetAllOpList() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Loads the PluggableDevice plugin found at `library_filename`, storing the
// outcome in `status`. Returns a newly allocated TF_Library wrapping the
// loaded handle on success, or nullptr if loading failed.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
  TF_Library* library = new TF_Library;
  status->status = tensorflow::LoadPluggableDeviceLibrary(
      library_filename, &library->lib_handle);
  if (status->status.ok()) {
    return library;
  }
  // Loading failed: free the wrapper before reporting the failure.
  delete library;
  return nullptr;
}
|
||||
|
||||
// Deletes the TF_Library object referenced by `lib_handle`. A null handle is
// accepted and ignored.
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle != nullptr) {
    delete lib_handle;
  }
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// ListDevices & SessionListDevices API
|
||||
|
||||
|
@ -630,6 +630,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
|
||||
|
||||
namespace tensorflow {
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
|
||||
// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
|
||||
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
|
||||
} // namespace tensorflow
|
||||
|
||||
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
@ -743,3 +746,26 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
TF_ImportGraphDefOptions* opts, unsigned char enable) {
|
||||
opts->opts.validate_colocation_constraints = enable;
|
||||
}
|
||||
|
||||
// Loads the PluggableDevice plugin at `library_filename` and wraps the
// resulting platform handle in a newly allocated TF_Library.
//
// The outcome is written to `status`. On success the new handle is returned;
// on failure nullptr is returned. When built with IS_MOBILE_PLATFORM or
// IS_SLIM_BUILD, no loading is attempted: `status` is set to Unimplemented
// and nullptr is returned.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* library = new TF_Library;
  status->status = tensorflow::LoadPluggableDeviceLibrary(
      library_filename, &library->lib_handle);
  if (status->status.ok()) {
    return library;
  }
  // Loading failed: release the wrapper so nothing leaks.
  delete library;
  return nullptr;
#endif
}
|
||||
|
||||
// Frees the TF_Library wrapper referenced by `lib_handle`; passing nullptr is
// a no-op. Only the handle object itself is deleted here.
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle != nullptr) {
    delete lib_handle;
  }
}
|
||||
|
@ -304,6 +304,25 @@ TF_CAPI_EXPORT extern void
|
||||
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
TF_ImportGraphDefOptions* opts, unsigned char enable);
|
||||
|
||||
// Load the library specified by library_filename and register the pluggable
|
||||
// device and related kernels present in that library.
|
||||
//
|
||||
// Pass "library_filename" to a platform-specific mechanism for dynamically
|
||||
// loading a library. The rules for determining the exact location of the
|
||||
// library are platform-specific and are not documented here.
|
||||
//
|
||||
// On success, place OK in status and return the newly created library handle.
|
||||
// The caller owns the library handle.
|
||||
//
|
||||
// On failure, place an error status in status and return NULL.
|
||||
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
|
||||
const char* library_filename, TF_Status* status);
|
||||
|
||||
// Frees the memory associated with the library handle.
|
||||
// Does NOT unload the library.
|
||||
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
|
||||
TF_Library* lib_handle);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
|
||||
TF_DeleteTensor(tensor_1X6);
|
||||
}
|
||||
|
||||
// Verifies that the test PluggableDevice plugin can be loaded through the C
// API and that the returned library handle can be released afterwards.
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  TF_Status* load_status = TF_NewStatus();

  // Resolve the plugin's location among the test's data dependencies.
  const string plugin_path = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "c", "experimental",
                               "stream_executor", "test",
                               "test_pluggable_device.so"));
  TF_Library* plugin =
      TF_LoadPluggableDeviceLibrary(plugin_path.c_str(), load_status);

  // Copy the outcome out of the status so it can be freed before asserting.
  const TF_Code load_code = TF_GetCode(load_status);
  const string load_message(TF_Message(load_status));
  TF_DeleteStatus(load_status);

  ASSERT_EQ(TF_OK, load_code) << load_message;
  TF_DeletePluggableDeviceLibraryHandle(plugin);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -228,21 +228,6 @@ TEST(CAPI, LibraryLoadFunctions) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that the test PluggableDevice plugin shipped next to the C API tests
// loads successfully and that its handle can be deleted.
TEST(CAPI, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  TF_Status* load_status = TF_NewStatus();

  // Resolve the plugin's location among the test's data dependencies.
  const string plugin_path = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "c", "test_pluggable_device.so"));
  TF_Library* plugin =
      TF_LoadPluggableDeviceLibrary(plugin_path.c_str(), load_status);

  // Copy the outcome out of the status so it can be freed before asserting.
  const TF_Code load_code = TF_GetCode(load_status);
  const string load_message(TF_Message(load_status));
  TF_DeleteStatus(load_status);

  ASSERT_EQ(TF_OK, load_code) << load_message;
  TF_DeletePluggableDeviceLibraryHandle(plugin);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}
|
||||
|
||||
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||
const tensorflow::int64 n = data.size();
|
||||
Status status;
|
||||
|
@ -37,6 +37,7 @@ cc_library(
|
||||
"stream_executor.h",
|
||||
"stream_executor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
@ -67,4 +68,3 @@ exports_files(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/status.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
|
||||
using tensorflow::StatusFromTF_Status;
|
||||
namespace stream_executor {
|
||||
|
||||
// Plugin initialization function that a device plugin
|
||||
@ -103,7 +102,7 @@ class CStream : public internal::StreamInterface {
|
||||
// Invokes the create_stream callback of stream_executor_ for device_, storing
// the new handle in stream_handle_, and returns the reported TF_Status
// converted to a port::Status.
port::Status Create() {
  OwnedTFStatus c_status(TF_NewStatus());
  stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
  // Fully qualified so the call does not depend on a file-scope
  // using-declaration.
  return tensorflow::StatusFromTF_Status(c_status.get());
}
|
||||
|
||||
@ -133,14 +132,14 @@ class CEvent : public internal::EventInterface {
|
||||
// Invokes the create_event callback of stream_executor_ for device_, storing
// the new handle in event_handle_, and returns the reported TF_Status
// converted to a port::Status.
port::Status Create() {
  OwnedTFStatus c_status(TF_NewStatus());
  stream_executor_->create_event(device_, &event_handle_, c_status.get());
  // Fully qualified so the call does not depend on a file-scope
  // using-declaration.
  return tensorflow::StatusFromTF_Status(c_status.get());
}
|
||||
|
||||
// Invokes the record_event callback of stream_executor_ with device_,
// `stream_handle` and event_handle_, and returns the reported TF_Status
// converted to a port::Status.
port::Status Record(SP_Stream stream_handle) {
  OwnedTFStatus c_status(TF_NewStatus());
  stream_executor_->record_event(device_, stream_handle, event_handle_,
                                 c_status.get());
  // Fully qualified so the call does not depend on a file-scope
  // using-declaration.
  return tensorflow::StatusFromTF_Status(c_status.get());
}
|
||||
|
||||
void Destroy() {
|
||||
@ -171,7 +170,7 @@ class CTimer : public internal::TimerInterface {
|
||||
// Invokes the create_timer callback of stream_executor_ for device_, storing
// the new handle in timer_handle_, and returns the reported TF_Status
// converted to a port::Status.
port::Status Create() {
  OwnedTFStatus c_status(TF_NewStatus());
  stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
  // Fully qualified so the call does not depend on a file-scope
  // using-declaration.
  return tensorflow::StatusFromTF_Status(c_status.get());
}
|
||||
|
||||
void Destroy() {
|
||||
|
21
tensorflow/c/experimental/stream_executor/test/BUILD
Normal file
21
tensorflow/c/experimental/stream_executor/test/BUILD
Normal file
@ -0,0 +1,21 @@
|
||||
# Description:
|
||||
# test for stream_executor
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_custom_op_library",
|
||||
)
|
||||
|
||||
# Library target for the test PluggableDevice plugin: compiles
# test_pluggable_device.cc together with the headers listed in hdrs.
cc_library(
|
||||
name = "test_pluggable_device_impl",
|
||||
srcs = ["test_pluggable_device.cc"],
|
||||
hdrs = [
|
||||
"//tensorflow/c:headers",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor.h",
|
||||
],
|
||||
)
|
||||
|
||||
# Custom-op style library target named "test_pluggable_device.so", built from
# the plugin implementation library and visible to //tensorflow/c and its
# subpackages.
tf_custom_op_library(
|
||||
name = "test_pluggable_device.so",
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [":test_pluggable_device_impl"],
|
||||
)
|
@ -0,0 +1,24 @@
|
||||
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
|
||||
// Initialization function of the test plugin. Fills in the struct_size, name
// and type fields of the platform referenced by `params`; `status` is not
// modified.
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
                   TF_Status* const status) {
  auto* const platform = params->platform;
  platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  platform->name = "GPU";
  platform->type = "XGPU";
}
|
@ -170,13 +170,20 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
// This function is only for pluggable device
|
||||
// it will return nullptr in all other cases
|
||||
// This function is only for pluggable device.
|
||||
// It will return nullptr in all other cases.
|
||||
// This function is experimental and subject to change.
|
||||
SP_Stream TF_GetStream(TF_OpKernelContext* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
auto c_stream = static_cast<stream_executor::CStream*>(
|
||||
cc_ctx->op_device_context()->stream()->implementation());
|
||||
return c_stream->Handle();
|
||||
if (cc_ctx->op_device_context() == nullptr) { // CPU Device
|
||||
return nullptr;
|
||||
} else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
|
||||
return nullptr;
|
||||
} else { // Is a PluggableDevice
|
||||
auto c_stream = static_cast<stream_executor::CStream*>(
|
||||
cc_ctx->op_device_context()->stream()->implementation());
|
||||
return c_stream->Handle();
|
||||
}
|
||||
}
|
||||
|
||||
int TF_NumInputs(TF_OpKernelContext* ctx) {
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -52,6 +51,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
struct MyCustomKernel {
|
||||
bool created;
|
||||
@ -378,6 +378,20 @@ template <typename T>
|
||||
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||
TF_OpKernelContext* ctx);
|
||||
|
||||
REGISTER_OP("StreamOp").Output("output1: float");
|
||||
|
||||
// Runs a kernel whose compute function queries TF_GetStream and expects a
// null stream.
TEST_F(DeviceKernelOpTest, TestStream) {
  // Stream is always null if device is not a pluggable device. More test
  // cases will be added when pluggable device mechanism is supported.
  auto expect_null_stream = [](void* kernel, TF_OpKernelContext* ctx) {
    SP_Stream kernel_stream = TF_GetStream(ctx);
    EXPECT_EQ(kernel_stream, nullptr);
  };

  SetupOp("StreamOp", "StreamOp", expect_null_stream);
  TF_ASSERT_OK(RunOpKernel());
}
|
||||
|
||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||
|
@ -114,6 +114,9 @@ class DeviceContext : public core::RefCounted {
|
||||
std::function<void()> func) {
|
||||
return errors::Internal("ThenExecute not supported by device");
|
||||
}
|
||||
|
||||
// Reports whether this context belongs to a pluggable device. This default
// implementation always returns false.
|
||||
virtual bool IsPluggableDevice() { return false; }
|
||||
};
|
||||
|
||||
class DeviceBase {
|
||||
@ -173,7 +176,6 @@ class DeviceBase {
|
||||
// Does not take ownership.
|
||||
void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
|
||||
|
||||
|
||||
// Return the Allocator implementation to use based on the allocator
|
||||
// attributes requested. See allocator.h for more details.
|
||||
virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
|
||||
@ -204,7 +206,6 @@ class DeviceBase {
|
||||
|
||||
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
|
||||
|
||||
|
||||
// Caller owns the return value. The OpKernelContext calls this even
|
||||
// for devices that do not implement an eigen_gpu_device. Overridden
|
||||
// by GPU devices to return a derived type.
|
||||
|
@ -158,7 +158,6 @@ def load_library(library_location):
|
||||
'The file or folder to load kernel libraries from does not exist.',
|
||||
library_location)
|
||||
|
||||
@tf_export('load_pluggable_device_library')
|
||||
def load_pluggable_device_library(library_location):
|
||||
"""Loads a Tensorflow PluggableDevice plugin
|
||||
"library_location" can be a path to a specific shared object, or a folder.
|
||||
|
@ -1576,10 +1576,6 @@ tf_module {
|
||||
name: "load_op_library"
|
||||
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "load_pluggable_device_library"
|
||||
argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "local_variables"
|
||||
argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -772,10 +772,6 @@ tf_module {
|
||||
name: "load_op_library"
|
||||
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "load_pluggable_device_library"
|
||||
argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "logical_and"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user