Remove the stream executor / kernel C API changes, since they have been moved to another PR
This commit is contained in:
parent
97d2075ada
commit
b8c1ec7503
@ -548,7 +548,6 @@ tf_cuda_library(
|
||||
":tf_tensor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
@ -49,7 +49,6 @@ cc_library(
|
||||
"stream_executor.h",
|
||||
"stream_executor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -187,6 +188,41 @@ port::Status ValidateSEPlatformRegistrationParams(
|
||||
}
|
||||
#undef VALIDATE_MEMBER
|
||||
|
||||
// Deleter functor: frees a TF_Status by calling TF_DeleteStatus.
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
};
|
||||
// std::unique_ptr over TF_Status that uses TFStatusDeleter as its deleter, so
// the status is released when the pointer goes out of scope.
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
class CStream : public internal::StreamInterface {
|
||||
public:
|
||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
stream_handle_(nullptr) {}
|
||||
~CStream() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||
port::Status s = StatusFromTF_Status(c_status.get());
|
||||
return s;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (stream_handle_ != nullptr) {
|
||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||
stream_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Stream Handle() { return stream_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Stream stream_handle_;
|
||||
};
|
||||
|
||||
// Converts SE_EventStatus to Event::Status.
|
||||
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||
switch (s) {
|
||||
@ -201,6 +237,82 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||
}
|
||||
}
|
||||
|
||||
class CEvent : public internal::EventInterface {
|
||||
public:
|
||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
event_handle_(nullptr) {}
|
||||
~CEvent() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
port::Status Record(SP_Stream stream_handle) {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||
c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (event_handle_ != nullptr) {
|
||||
stream_executor_->destroy_event(device_, event_handle_);
|
||||
event_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Event Handle() { return event_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Event event_handle_;
|
||||
};
|
||||
|
||||
class CTimer : public internal::TimerInterface {
|
||||
public:
|
||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_handle_(nullptr),
|
||||
timer_fns_(timer_fns) {}
|
||||
~CTimer() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (timer_handle_ != nullptr) {
|
||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||
timer_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Timer Handle() { return timer_handle_; }
|
||||
|
||||
uint64 Microseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||
}
|
||||
|
||||
uint64 Nanoseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_);
|
||||
}
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Timer timer_handle_;
|
||||
SP_TimerFns* timer_fns_;
|
||||
};
|
||||
|
||||
// Converts DeviceMemoryBase to a C struct.
|
||||
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
|
||||
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/stream_executor/executor_cache.h"
|
||||
#include "tensorflow/stream_executor/lib/status.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
@ -38,13 +37,6 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
|
||||
// testing).
|
||||
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
|
||||
|
||||
namespace {
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
};
|
||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
} // namespace
|
||||
|
||||
class CPlatform : public Platform {
|
||||
public:
|
||||
explicit CPlatform(SP_Platform platform,
|
||||
@ -91,111 +83,5 @@ class CPlatform : public Platform {
|
||||
stream_executor::ExecutorCache executor_cache_;
|
||||
};
|
||||
|
||||
class CStream : public internal::StreamInterface {
|
||||
public:
|
||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
stream_handle_(nullptr) {}
|
||||
~CStream() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||
port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
|
||||
return s;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (stream_handle_ != nullptr) {
|
||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||
stream_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Stream Handle() { return stream_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Stream stream_handle_;
|
||||
};
|
||||
|
||||
class CEvent : public internal::EventInterface {
|
||||
public:
|
||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
event_handle_(nullptr) {}
|
||||
~CEvent() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
port::Status Record(SP_Stream stream_handle) {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||
c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (event_handle_ != nullptr) {
|
||||
stream_executor_->destroy_event(device_, event_handle_);
|
||||
event_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Event Handle() { return event_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Event event_handle_;
|
||||
};
|
||||
|
||||
class CTimer : public internal::TimerInterface {
|
||||
public:
|
||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_handle_(nullptr),
|
||||
timer_fns_(timer_fns) {}
|
||||
~CTimer() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (timer_handle_ != nullptr) {
|
||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||
timer_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Timer Handle() { return timer_handle_; }
|
||||
|
||||
uint64 Microseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||
}
|
||||
|
||||
uint64 Nanoseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_);
|
||||
}
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Timer timer_handle_;
|
||||
SP_TimerFns* timer_fns_;
|
||||
};
|
||||
|
||||
} // namespace stream_executor
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
@ -26,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/stream.h"
|
||||
|
||||
// This file forms the basis of a stable ABI for third-party kernel
|
||||
// implementations. It is crucial that changes to this file are made cautiously
|
||||
@ -170,22 +168,6 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
// This function is only for pluggable device.
|
||||
// It will return nullptr in all other cases.
|
||||
// This function is experimental and subject to change.
|
||||
SP_Stream TF_GetStream(TF_OpKernelContext* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
if (cc_ctx->op_device_context() == nullptr) { // CPU Device
|
||||
return nullptr;
|
||||
} else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
|
||||
return nullptr;
|
||||
} else { // Is a PluggableDevice
|
||||
auto c_stream = static_cast<stream_executor::CStream*>(
|
||||
cc_ctx->op_device_context()->stream()->implementation());
|
||||
return c_stream->Handle();
|
||||
}
|
||||
}
|
||||
|
||||
int TF_NumInputs(TF_OpKernelContext* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
return cc_ctx->num_inputs();
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
@ -66,11 +65,6 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
|
||||
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
|
||||
typedef struct TF_OpKernelContext TF_OpKernelContext;
|
||||
|
||||
// TF_InitKernel performs op/kernel registration.
|
||||
// Plugin should either implement TF_InitKernel to register kernels or use
|
||||
// static registration. This function should register all kernels in a plugin.
|
||||
void TF_InitKernel();
|
||||
|
||||
// Allocates a new kernel builder and returns a pointer to it.
|
||||
//
|
||||
// If non-null, TensorFlow will call create_func when it needs to instantiate
|
||||
@ -134,15 +128,6 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
|
||||
// --------------------------------------------------------------------------
|
||||
// OpKernelContext routines
|
||||
|
||||
// TF_GetStream returns the SP_Stream available in ctx.
|
||||
// This function returns a stream only for devices registered using the
|
||||
// StreamExecutor C API
|
||||
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
|
||||
// nullptr in all other cases.
|
||||
// Experimental: this function doesn't have compatibility guarantees and is
|
||||
// subject to change at any time.
|
||||
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);
|
||||
|
||||
// TF_NumInputs returns the number of inputs available in ctx.
|
||||
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
|
||||
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -51,7 +52,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
struct MyCustomKernel {
|
||||
bool created;
|
||||
@ -378,20 +378,6 @@ template <typename T>
|
||||
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||
TF_OpKernelContext* ctx);
|
||||
|
||||
REGISTER_OP("StreamOp").Output("output1: float");

// Runs a kernel whose compute function only queries TF_GetStream and expects
// to get no stream back.
TEST_F(DeviceKernelOpTest, TestStream) {
  auto stream_compute_fn = [](void* kernel, TF_OpKernelContext* ctx) {
    // A non-pluggable device never has a stream, so nullptr is the only
    // possible result here. Further cases are to be added once the pluggable
    // device mechanism is supported.
    const SP_Stream stream = TF_GetStream(ctx);
    EXPECT_EQ(stream, nullptr);
  };

  SetupOp("StreamOp", "StreamOp", stream_compute_fn);
  TF_ASSERT_OK(RunOpKernel());
}
|
||||
|
||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||
|
@ -114,9 +114,6 @@ class DeviceContext : public core::RefCounted {
|
||||
std::function<void()> func) {
|
||||
return errors::Internal("ThenExecute not supported by device");
|
||||
}
|
||||
|
||||
// Reports whether this device context belongs to a pluggable device. The
// default implementation returns false; being virtual, it can be overridden
// by subclasses.
|
||||
virtual bool IsPluggableDevice() { return false; }
|
||||
};
|
||||
|
||||
class DeviceBase {
|
||||
|
Loading…
Reference in New Issue
Block a user