remove changes for stream executor / kernel C API since it has been moved to another PR
This commit is contained in:
parent
97d2075ada
commit
b8c1ec7503
@ -548,7 +548,6 @@ tf_cuda_library(
|
|||||||
":tf_tensor",
|
":tf_tensor",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_lite",
|
"//tensorflow/core:framework_lite",
|
||||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -49,7 +49,6 @@ cc_library(
|
|||||||
"stream_executor.h",
|
"stream_executor.h",
|
||||||
"stream_executor_internal.h",
|
"stream_executor_internal.h",
|
||||||
],
|
],
|
||||||
visibility = ["//tensorflow/c:__subpackages__"],
|
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/c:c_api_macros",
|
"//tensorflow/c:c_api_macros",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -187,6 +188,41 @@ port::Status ValidateSEPlatformRegistrationParams(
|
|||||||
}
|
}
|
||||||
#undef VALIDATE_MEMBER
|
#undef VALIDATE_MEMBER
|
||||||
|
|
||||||
|
struct TFStatusDeleter {
|
||||||
|
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||||
|
};
|
||||||
|
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||||
|
|
||||||
|
class CStream : public internal::StreamInterface {
|
||||||
|
public:
|
||||||
|
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||||
|
: device_(device),
|
||||||
|
stream_executor_(stream_executor),
|
||||||
|
stream_handle_(nullptr) {}
|
||||||
|
~CStream() override { Destroy(); }
|
||||||
|
|
||||||
|
port::Status Create() {
|
||||||
|
OwnedTFStatus c_status(TF_NewStatus());
|
||||||
|
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||||
|
port::Status s = StatusFromTF_Status(c_status.get());
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Destroy() {
|
||||||
|
if (stream_handle_ != nullptr) {
|
||||||
|
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||||
|
stream_handle_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SP_Stream Handle() { return stream_handle_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
SP_Device* device_;
|
||||||
|
SP_StreamExecutor* stream_executor_;
|
||||||
|
SP_Stream stream_handle_;
|
||||||
|
};
|
||||||
|
|
||||||
// Converts SE_EventStatus to Event::Status.
|
// Converts SE_EventStatus to Event::Status.
|
||||||
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||||
switch (s) {
|
switch (s) {
|
||||||
@ -201,6 +237,82 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class CEvent : public internal::EventInterface {
|
||||||
|
public:
|
||||||
|
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||||
|
: device_(device),
|
||||||
|
stream_executor_(stream_executor),
|
||||||
|
event_handle_(nullptr) {}
|
||||||
|
~CEvent() override { Destroy(); }
|
||||||
|
|
||||||
|
port::Status Create() {
|
||||||
|
OwnedTFStatus c_status(TF_NewStatus());
|
||||||
|
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||||
|
return StatusFromTF_Status(c_status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
port::Status Record(SP_Stream stream_handle) {
|
||||||
|
OwnedTFStatus c_status(TF_NewStatus());
|
||||||
|
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||||
|
c_status.get());
|
||||||
|
return StatusFromTF_Status(c_status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Destroy() {
|
||||||
|
if (event_handle_ != nullptr) {
|
||||||
|
stream_executor_->destroy_event(device_, event_handle_);
|
||||||
|
event_handle_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SP_Event Handle() { return event_handle_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
SP_Device* device_;
|
||||||
|
SP_StreamExecutor* stream_executor_;
|
||||||
|
SP_Event event_handle_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class CTimer : public internal::TimerInterface {
|
||||||
|
public:
|
||||||
|
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||||
|
SP_TimerFns* timer_fns)
|
||||||
|
: device_(device),
|
||||||
|
stream_executor_(stream_executor),
|
||||||
|
timer_handle_(nullptr),
|
||||||
|
timer_fns_(timer_fns) {}
|
||||||
|
~CTimer() override { Destroy(); }
|
||||||
|
|
||||||
|
port::Status Create() {
|
||||||
|
OwnedTFStatus c_status(TF_NewStatus());
|
||||||
|
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||||
|
return StatusFromTF_Status(c_status.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Destroy() {
|
||||||
|
if (timer_handle_ != nullptr) {
|
||||||
|
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||||
|
timer_handle_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SP_Timer Handle() { return timer_handle_; }
|
||||||
|
|
||||||
|
uint64 Microseconds() const override {
|
||||||
|
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64 Nanoseconds() const override {
|
||||||
|
return timer_fns_->nanoseconds(timer_handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SP_Device* device_;
|
||||||
|
SP_StreamExecutor* stream_executor_;
|
||||||
|
SP_Timer timer_handle_;
|
||||||
|
SP_TimerFns* timer_fns_;
|
||||||
|
};
|
||||||
|
|
||||||
// Converts DeviceMemoryBase to a C struct.
|
// Converts DeviceMemoryBase to a C struct.
|
||||||
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
|
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
|
||||||
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
|
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
|
||||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
|
||||||
#include "tensorflow/stream_executor/executor_cache.h"
|
#include "tensorflow/stream_executor/executor_cache.h"
|
||||||
#include "tensorflow/stream_executor/lib/status.h"
|
#include "tensorflow/stream_executor/lib/status.h"
|
||||||
#include "tensorflow/stream_executor/platform.h"
|
#include "tensorflow/stream_executor/platform.h"
|
||||||
@ -38,13 +37,6 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
|
|||||||
// testing).
|
// testing).
|
||||||
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
|
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
|
||||||
|
|
||||||
namespace {
|
|
||||||
struct TFStatusDeleter {
|
|
||||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
|
||||||
};
|
|
||||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
class CPlatform : public Platform {
|
class CPlatform : public Platform {
|
||||||
public:
|
public:
|
||||||
explicit CPlatform(SP_Platform platform,
|
explicit CPlatform(SP_Platform platform,
|
||||||
@ -91,111 +83,5 @@ class CPlatform : public Platform {
|
|||||||
stream_executor::ExecutorCache executor_cache_;
|
stream_executor::ExecutorCache executor_cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CStream : public internal::StreamInterface {
|
|
||||||
public:
|
|
||||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
|
||||||
: device_(device),
|
|
||||||
stream_executor_(stream_executor),
|
|
||||||
stream_handle_(nullptr) {}
|
|
||||||
~CStream() override { Destroy(); }
|
|
||||||
|
|
||||||
port::Status Create() {
|
|
||||||
OwnedTFStatus c_status(TF_NewStatus());
|
|
||||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
|
||||||
port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Destroy() {
|
|
||||||
if (stream_handle_ != nullptr) {
|
|
||||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
|
||||||
stream_handle_ = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SP_Stream Handle() { return stream_handle_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
SP_Device* device_;
|
|
||||||
SP_StreamExecutor* stream_executor_;
|
|
||||||
SP_Stream stream_handle_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class CEvent : public internal::EventInterface {
|
|
||||||
public:
|
|
||||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
|
||||||
: device_(device),
|
|
||||||
stream_executor_(stream_executor),
|
|
||||||
event_handle_(nullptr) {}
|
|
||||||
~CEvent() override { Destroy(); }
|
|
||||||
|
|
||||||
port::Status Create() {
|
|
||||||
OwnedTFStatus c_status(TF_NewStatus());
|
|
||||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
|
||||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
port::Status Record(SP_Stream stream_handle) {
|
|
||||||
OwnedTFStatus c_status(TF_NewStatus());
|
|
||||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
|
||||||
c_status.get());
|
|
||||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
void Destroy() {
|
|
||||||
if (event_handle_ != nullptr) {
|
|
||||||
stream_executor_->destroy_event(device_, event_handle_);
|
|
||||||
event_handle_ = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SP_Event Handle() { return event_handle_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
SP_Device* device_;
|
|
||||||
SP_StreamExecutor* stream_executor_;
|
|
||||||
SP_Event event_handle_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class CTimer : public internal::TimerInterface {
|
|
||||||
public:
|
|
||||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
|
||||||
SP_TimerFns* timer_fns)
|
|
||||||
: device_(device),
|
|
||||||
stream_executor_(stream_executor),
|
|
||||||
timer_handle_(nullptr),
|
|
||||||
timer_fns_(timer_fns) {}
|
|
||||||
~CTimer() override { Destroy(); }
|
|
||||||
|
|
||||||
port::Status Create() {
|
|
||||||
OwnedTFStatus c_status(TF_NewStatus());
|
|
||||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
|
||||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
void Destroy() {
|
|
||||||
if (timer_handle_ != nullptr) {
|
|
||||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
|
||||||
timer_handle_ = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SP_Timer Handle() { return timer_handle_; }
|
|
||||||
|
|
||||||
uint64 Microseconds() const override {
|
|
||||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64 Nanoseconds() const override {
|
|
||||||
return timer_fns_->nanoseconds(timer_handle_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
SP_Device* device_;
|
|
||||||
SP_StreamExecutor* stream_executor_;
|
|
||||||
SP_Timer timer_handle_;
|
|
||||||
SP_TimerFns* timer_fns_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
||||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/c/tf_tensor_internal.h"
|
#include "tensorflow/c/tf_tensor_internal.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
@ -26,7 +25,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/stream_executor/stream.h"
|
|
||||||
|
|
||||||
// This file forms the basis of a stable ABI for third-party kernel
|
// This file forms the basis of a stable ABI for third-party kernel
|
||||||
// implementations. It is crucial that changes to this file are made cautiously
|
// implementations. It is crucial that changes to this file are made cautiously
|
||||||
@ -170,22 +168,6 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
|
|||||||
TF_SetStatus(status, TF_OK, "");
|
TF_SetStatus(status, TF_OK, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function is only for pluggable device.
|
|
||||||
// It will return nullptr in all other cases.
|
|
||||||
// This function is experimental and subject to change.
|
|
||||||
SP_Stream TF_GetStream(TF_OpKernelContext* ctx) {
|
|
||||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
|
||||||
if (cc_ctx->op_device_context() == nullptr) { // CPU Device
|
|
||||||
return nullptr;
|
|
||||||
} else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
|
|
||||||
return nullptr;
|
|
||||||
} else { // Is a PluggableDevice
|
|
||||||
auto c_stream = static_cast<stream_executor::CStream*>(
|
|
||||||
cc_ctx->op_device_context()->stream()->implementation());
|
|
||||||
return c_stream->Handle();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int TF_NumInputs(TF_OpKernelContext* ctx) {
|
int TF_NumInputs(TF_OpKernelContext* ctx) {
|
||||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||||
return cc_ctx->num_inputs();
|
return cc_ctx->num_inputs();
|
||||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
|||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
@ -66,11 +65,6 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
|
|||||||
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
|
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
|
||||||
typedef struct TF_OpKernelContext TF_OpKernelContext;
|
typedef struct TF_OpKernelContext TF_OpKernelContext;
|
||||||
|
|
||||||
// TF_InitKernel to do op/kernel registration.
|
|
||||||
// Plugin should either implement TF_InitKernel to register kernels or use
|
|
||||||
// static registration. This function should register all kernels in a plugin.
|
|
||||||
void TF_InitKernel();
|
|
||||||
|
|
||||||
// Allocates a new kernel builder and returns a pointer to it.
|
// Allocates a new kernel builder and returns a pointer to it.
|
||||||
//
|
//
|
||||||
// If non-null, TensorFlow will call create_func when it needs to instantiate
|
// If non-null, TensorFlow will call create_func when it needs to instantiate
|
||||||
@ -134,15 +128,6 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
|
|||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// OpKernelContext routines
|
// OpKernelContext routines
|
||||||
|
|
||||||
// TF_GetStream returns the SP_Stream available in ctx.
|
|
||||||
// This function returns a stream only for devices registered using the
|
|
||||||
// StreamExecutor C API
|
|
||||||
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
|
|
||||||
// nullptr in all other cases.
|
|
||||||
// Experimental: this function doesn't have compatibility guarantees and subject
|
|
||||||
// to change at any time.
|
|
||||||
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);
|
|
||||||
|
|
||||||
// TF_NumInputs returns the number of inputs available in ctx.
|
// TF_NumInputs returns the number of inputs available in ctx.
|
||||||
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
|
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
@ -51,7 +52,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
||||||
|
|
||||||
struct MyCustomKernel {
|
struct MyCustomKernel {
|
||||||
bool created;
|
bool created;
|
||||||
@ -378,20 +378,6 @@ template <typename T>
|
|||||||
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||||
TF_OpKernelContext* ctx);
|
TF_OpKernelContext* ctx);
|
||||||
|
|
||||||
REGISTER_OP("StreamOp").Output("output1: float");
|
|
||||||
|
|
||||||
TEST_F(DeviceKernelOpTest, TestStream) {
|
|
||||||
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
|
||||||
SP_Stream stream = TF_GetStream(ctx);
|
|
||||||
// Stream is always null if device is not a pluggable device. More test
|
|
||||||
// cases will be added when pluggable device mechanism is supported.
|
|
||||||
EXPECT_EQ(stream, nullptr);
|
|
||||||
};
|
|
||||||
|
|
||||||
SetupOp("StreamOp", "StreamOp", my_compute_func);
|
|
||||||
TF_ASSERT_OK(RunOpKernel());
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||||
|
|
||||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||||
|
@ -114,9 +114,6 @@ class DeviceContext : public core::RefCounted {
|
|||||||
std::function<void()> func) {
|
std::function<void()> func) {
|
||||||
return errors::Internal("ThenExecute not supported by device");
|
return errors::Internal("ThenExecute not supported by device");
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if device is a pluggable device
|
|
||||||
virtual bool IsPluggableDevice() { return false; }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class DeviceBase {
|
class DeviceBase {
|
||||||
|
Loading…
Reference in New Issue
Block a user