address review comments
This commit is contained in:
parent
f599aed5d8
commit
b8a26f2d74
@ -492,8 +492,6 @@ tf_cuda_library(
|
|||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"kernels.h",
|
"kernels.h",
|
||||||
"//tensorflow/c/experimental/stream_executor:stream_executor.h",
|
|
||||||
"//tensorflow/c/experimental/stream_executor:stream_executor_internal.h",
|
|
||||||
],
|
],
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
@ -501,7 +499,7 @@ tf_cuda_library(
|
|||||||
":tf_status",
|
":tf_status",
|
||||||
":tf_status_helper",
|
":tf_status_helper",
|
||||||
":tf_tensor_internal",
|
":tf_tensor_internal",
|
||||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
||||||
] + select({
|
] + select({
|
||||||
"//tensorflow:android": [
|
"//tensorflow:android": [
|
||||||
":c_api_internal",
|
":c_api_internal",
|
||||||
@ -581,7 +579,6 @@ tf_cuda_cc_test(
|
|||||||
srcs = ["c_api_test.cc"],
|
srcs = ["c_api_test.cc"],
|
||||||
data = [
|
data = [
|
||||||
":test_op1.so",
|
":test_op1.so",
|
||||||
":test_pluggable_device.so",
|
|
||||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||||
],
|
],
|
||||||
linkopts = select({
|
linkopts = select({
|
||||||
@ -633,7 +630,10 @@ tf_cc_test(
|
|||||||
name = "c_api_experimental_test",
|
name = "c_api_experimental_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["c_api_experimental_test.cc"],
|
srcs = ["c_api_experimental_test.cc"],
|
||||||
data = ["testdata/tf_record"],
|
data = [
|
||||||
|
"testdata/tf_record",
|
||||||
|
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
|
||||||
|
],
|
||||||
linkopts = select({
|
linkopts = select({
|
||||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
@ -653,6 +653,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:resource_loader",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -692,19 +693,6 @@ tf_custom_op_library(
|
|||||||
srcs = ["test_op1.cc"],
|
srcs = ["test_op1.cc"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "test_pluggable_device_impl",
|
|
||||||
srcs = ["test_pluggable_device.cc"],
|
|
||||||
hdrs = ["c_api_macros.h",
|
|
||||||
"tf_status.h",
|
|
||||||
"//tensorflow/c/experimental/stream_executor:stream_executor.h",],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_custom_op_library(
|
|
||||||
name = "test_pluggable_device.so",
|
|
||||||
deps = [":test_pluggable_device_impl"],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "test_op_kernel",
|
name = "test_op_kernel",
|
||||||
srcs = ["test_op.cc"],
|
srcs = ["test_op.cc"],
|
||||||
|
@ -308,9 +308,6 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
|
|||||||
Status LoadDynamicLibrary(const char* library_filename, void** result,
|
Status LoadDynamicLibrary(const char* library_filename, void** result,
|
||||||
const void** buf, size_t* len);
|
const void** buf, size_t* len);
|
||||||
|
|
||||||
// Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file)
|
|
||||||
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
|
|
||||||
|
|
||||||
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
|
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
|
||||||
// directly, instead of requiring us to serialize to a GraphDef and
|
// directly, instead of requiring us to serialize to a GraphDef and
|
||||||
// call Session::Extend().
|
// call Session::Extend().
|
||||||
@ -584,23 +581,6 @@ TF_Buffer* TF_GetAllOpList() {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
|
|
||||||
TF_Status* status) {
|
|
||||||
TF_Library* lib_handle = new TF_Library;
|
|
||||||
status->status = tensorflow::LoadPluggableDeviceLibrary(
|
|
||||||
library_filename, &lib_handle->lib_handle);
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
delete lib_handle;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return lib_handle;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
|
|
||||||
if (lib_handle == nullptr) return;
|
|
||||||
delete lib_handle;
|
|
||||||
}
|
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// ListDevices & SessionListDevices API
|
// ListDevices & SessionListDevices API
|
||||||
|
|
||||||
|
@ -630,6 +630,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||||
|
|
||||||
|
// Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file).
|
||||||
|
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||||
@ -743,3 +746,26 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
|||||||
TF_ImportGraphDefOptions* opts, unsigned char enable) {
|
TF_ImportGraphDefOptions* opts, unsigned char enable) {
|
||||||
opts->opts.validate_colocation_constraints = enable;
|
opts->opts.validate_colocation_constraints = enable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
|
||||||
|
TF_Status* status) {
|
||||||
|
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||||
|
status->status = tensorflow::errors::Unimplemented(
|
||||||
|
"PluggableDevice plugin functionality is not supported on mobile");
|
||||||
|
return nullptr;
|
||||||
|
#else
|
||||||
|
TF_Library* lib_handle = new TF_Library;
|
||||||
|
status->status = tensorflow::LoadPluggableDeviceLibrary(
|
||||||
|
library_filename, &lib_handle->lib_handle);
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
delete lib_handle;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return lib_handle;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
|
||||||
|
if (lib_handle == nullptr) return;
|
||||||
|
delete lib_handle;
|
||||||
|
}
|
||||||
|
@ -304,6 +304,25 @@ TF_CAPI_EXPORT extern void
|
|||||||
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||||
TF_ImportGraphDefOptions* opts, unsigned char enable);
|
TF_ImportGraphDefOptions* opts, unsigned char enable);
|
||||||
|
|
||||||
|
// Load the library specified by library_filename and register the pluggable
|
||||||
|
// device and related kernels present in that library.
|
||||||
|
//
|
||||||
|
// Pass "library_filename" to a platform-specific mechanism for dynamically
|
||||||
|
// loading a library. The reles for determining the exact location of the
|
||||||
|
// library are platform-specific and are not documented here
|
||||||
|
//
|
||||||
|
// On success, place OK in status and return the newly created library handle.
|
||||||
|
// The caller owns the library handle
|
||||||
|
//
|
||||||
|
// On failure, place an error status in status and return NULL.
|
||||||
|
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
|
||||||
|
const char* library_filename, TF_Status* status);
|
||||||
|
|
||||||
|
// Frees the memory associated with the library handle.
|
||||||
|
// Does NOT unload the library.
|
||||||
|
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
|
||||||
|
TF_Library* lib_handle);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/resource_loader.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||||
|
|
||||||
@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
|
|||||||
TF_DeleteTensor(tensor_1X6);
|
TF_DeleteTensor(tensor_1X6);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
|
||||||
|
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||||
|
// Load the library.
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
string lib_path =
|
||||||
|
tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
|
||||||
|
"tensorflow", "c", "experimental", "stream_executor", "test",
|
||||||
|
"test_pluggable_device.so"));
|
||||||
|
TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
|
||||||
|
TF_Code code = TF_GetCode(status);
|
||||||
|
string status_msg(TF_Message(status));
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
ASSERT_EQ(TF_OK, code) << status_msg;
|
||||||
|
TF_DeletePluggableDeviceLibraryHandle(lib);
|
||||||
|
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -228,21 +228,6 @@ TEST(CAPI, LibraryLoadFunctions) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, LibraryPluggableDeviceLoadFunctions) {
|
|
||||||
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
|
||||||
// Load the library.
|
|
||||||
TF_Status* status = TF_NewStatus();
|
|
||||||
string lib_path = tensorflow::GetDataDependencyFilepath(
|
|
||||||
tensorflow::io::JoinPath("tensorflow", "c", "test_pluggable_device.so"));
|
|
||||||
TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
|
|
||||||
TF_Code code = TF_GetCode(status);
|
|
||||||
string status_msg(TF_Message(status));
|
|
||||||
TF_DeleteStatus(status);
|
|
||||||
ASSERT_EQ(TF_OK, code) << status_msg;
|
|
||||||
TF_DeletePluggableDeviceLibraryHandle(lib);
|
|
||||||
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||||
const tensorflow::int64 n = data.size();
|
const tensorflow::int64 n = data.size();
|
||||||
Status status;
|
Status status;
|
||||||
|
@ -37,6 +37,7 @@ cc_library(
|
|||||||
"stream_executor.h",
|
"stream_executor.h",
|
||||||
"stream_executor_internal.h",
|
"stream_executor_internal.h",
|
||||||
],
|
],
|
||||||
|
visibility = ["//tensorflow/c:__subpackages__"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/c:c_api_macros",
|
"//tensorflow/c:c_api_macros",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
@ -67,4 +68,3 @@ exports_files(
|
|||||||
],
|
],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/lib/status.h"
|
#include "tensorflow/stream_executor/lib/status.h"
|
||||||
#include "tensorflow/stream_executor/platform.h"
|
#include "tensorflow/stream_executor/platform.h"
|
||||||
|
|
||||||
using tensorflow::StatusFromTF_Status;
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
|
|
||||||
// Plugin initialization function that a device plugin
|
// Plugin initialization function that a device plugin
|
||||||
@ -103,7 +102,7 @@ class CStream : public internal::StreamInterface {
|
|||||||
port::Status Create() {
|
port::Status Create() {
|
||||||
OwnedTFStatus c_status(TF_NewStatus());
|
OwnedTFStatus c_status(TF_NewStatus());
|
||||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||||
port::Status s = StatusFromTF_Status(c_status.get());
|
port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,14 +132,14 @@ class CEvent : public internal::EventInterface {
|
|||||||
port::Status Create() {
|
port::Status Create() {
|
||||||
OwnedTFStatus c_status(TF_NewStatus());
|
OwnedTFStatus c_status(TF_NewStatus());
|
||||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||||
return StatusFromTF_Status(c_status.get());
|
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
port::Status Record(SP_Stream stream_handle) {
|
port::Status Record(SP_Stream stream_handle) {
|
||||||
OwnedTFStatus c_status(TF_NewStatus());
|
OwnedTFStatus c_status(TF_NewStatus());
|
||||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||||
c_status.get());
|
c_status.get());
|
||||||
return StatusFromTF_Status(c_status.get());
|
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Destroy() {
|
void Destroy() {
|
||||||
@ -171,7 +170,7 @@ class CTimer : public internal::TimerInterface {
|
|||||||
port::Status Create() {
|
port::Status Create() {
|
||||||
OwnedTFStatus c_status(TF_NewStatus());
|
OwnedTFStatus c_status(TF_NewStatus());
|
||||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||||
return StatusFromTF_Status(c_status.get());
|
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Destroy() {
|
void Destroy() {
|
||||||
|
21
tensorflow/c/experimental/stream_executor/test/BUILD
Normal file
21
tensorflow/c/experimental/stream_executor/test/BUILD
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Description:
|
||||||
|
# test for stream_executor
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_custom_op_library",
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "test_pluggable_device_impl",
|
||||||
|
srcs = ["test_pluggable_device.cc"],
|
||||||
|
hdrs = [
|
||||||
|
"//tensorflow/c:headers",
|
||||||
|
"//tensorflow/c/experimental/stream_executor:stream_executor.h",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "test_pluggable_device.so",
|
||||||
|
visibility = ["//tensorflow/c:__subpackages__"],
|
||||||
|
deps = [":test_pluggable_device_impl"],
|
||||||
|
)
|
@ -0,0 +1,24 @@
|
|||||||
|
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||||
|
|
||||||
|
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
|
||||||
|
TF_Status* const status) {
|
||||||
|
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
|
||||||
|
params->platform->name = "GPU";
|
||||||
|
params->platform->type = "XGPU";
|
||||||
|
}
|
@ -170,13 +170,20 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
|
|||||||
TF_SetStatus(status, TF_OK, "");
|
TF_SetStatus(status, TF_OK, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function is only for pluggable device
|
// This function is only for pluggable device.
|
||||||
// it will return nullptr in all other cases
|
// It will return nullptr in all other cases.
|
||||||
|
// This function is experimental and subject to change.
|
||||||
SP_Stream TF_GetStream(TF_OpKernelContext* ctx) {
|
SP_Stream TF_GetStream(TF_OpKernelContext* ctx) {
|
||||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||||
auto c_stream = static_cast<stream_executor::CStream*>(
|
if (cc_ctx->op_device_context() == nullptr) { // CPU Device
|
||||||
cc_ctx->op_device_context()->stream()->implementation());
|
return nullptr;
|
||||||
return c_stream->Handle();
|
} else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
|
||||||
|
return nullptr;
|
||||||
|
} else { // Is a PluggableDevice
|
||||||
|
auto c_stream = static_cast<stream_executor::CStream*>(
|
||||||
|
cc_ctx->op_device_context()->stream()->implementation());
|
||||||
|
return c_stream->Handle();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int TF_NumInputs(TF_OpKernelContext* ctx) {
|
int TF_NumInputs(TF_OpKernelContext* ctx) {
|
||||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
@ -52,6 +51,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
struct MyCustomKernel {
|
struct MyCustomKernel {
|
||||||
bool created;
|
bool created;
|
||||||
@ -378,6 +378,20 @@ template <typename T>
|
|||||||
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||||
TF_OpKernelContext* ctx);
|
TF_OpKernelContext* ctx);
|
||||||
|
|
||||||
|
REGISTER_OP("StreamOp").Output("output1: float");
|
||||||
|
|
||||||
|
TEST_F(DeviceKernelOpTest, TestStream) {
|
||||||
|
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||||
|
SP_Stream stream = TF_GetStream(ctx);
|
||||||
|
// Stream is always null if device is not a pluggable device. More test
|
||||||
|
// cases will be added when pluggable device mechanism is supported.
|
||||||
|
EXPECT_EQ(stream, nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
SetupOp("StreamOp", "StreamOp", my_compute_func);
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||||
|
|
||||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||||
|
@ -114,6 +114,9 @@ class DeviceContext : public core::RefCounted {
|
|||||||
std::function<void()> func) {
|
std::function<void()> func) {
|
||||||
return errors::Internal("ThenExecute not supported by device");
|
return errors::Internal("ThenExecute not supported by device");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check if device is a pluggable device
|
||||||
|
virtual bool IsPluggableDevice() { return false; }
|
||||||
};
|
};
|
||||||
|
|
||||||
class DeviceBase {
|
class DeviceBase {
|
||||||
@ -173,7 +176,6 @@ class DeviceBase {
|
|||||||
// Does not take ownership.
|
// Does not take ownership.
|
||||||
void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
|
void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
|
||||||
|
|
||||||
|
|
||||||
// Return the Allocator implementation to use based on the allocator
|
// Return the Allocator implementation to use based on the allocator
|
||||||
// attributes requested. See allocator.h for more details.
|
// attributes requested. See allocator.h for more details.
|
||||||
virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
|
virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
|
||||||
@ -204,7 +206,6 @@ class DeviceBase {
|
|||||||
|
|
||||||
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
|
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
|
||||||
|
|
||||||
|
|
||||||
// Caller owns the return value. The OpKernelContext calls this even
|
// Caller owns the return value. The OpKernelContext calls this even
|
||||||
// for devices that do not implement an eigen_gpu_device. Overridden
|
// for devices that do not implement an eigen_gpu_device. Overridden
|
||||||
// by GPU devices to return a derived type.
|
// by GPU devices to return a derived type.
|
||||||
|
@ -158,7 +158,6 @@ def load_library(library_location):
|
|||||||
'The file or folder to load kernel libraries from does not exist.',
|
'The file or folder to load kernel libraries from does not exist.',
|
||||||
library_location)
|
library_location)
|
||||||
|
|
||||||
@tf_export('load_pluggable_device_library')
|
|
||||||
def load_pluggable_device_library(library_location):
|
def load_pluggable_device_library(library_location):
|
||||||
"""Loads a Tensorflow PluggableDevice plugin
|
"""Loads a Tensorflow PluggableDevice plugin
|
||||||
"library_location" can be a path to a specific shared object, or a folder.
|
"library_location" can be a path to a specific shared object, or a folder.
|
||||||
|
@ -1576,10 +1576,6 @@ tf_module {
|
|||||||
name: "load_op_library"
|
name: "load_op_library"
|
||||||
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
member_method {
|
|
||||||
name: "load_pluggable_device_library"
|
|
||||||
argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
member_method {
|
||||||
name: "local_variables"
|
name: "local_variables"
|
||||||
argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -772,10 +772,6 @@ tf_module {
|
|||||||
name: "load_op_library"
|
name: "load_op_library"
|
||||||
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
member_method {
|
|
||||||
name: "load_pluggable_device_library"
|
|
||||||
argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
member_method {
|
||||||
name: "logical_and"
|
name: "logical_and"
|
||||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user