Merge pull request #44747 from Intel-tensorflow:stream_executor_C_API_extension
PiperOrigin-RevId: 344156372 Change-Id: I036a76c7ccf5f25e45767b3bf3270cbed9a56830
This commit is contained in:
commit
aa22defe05
@ -522,6 +522,7 @@ cc_library(
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_tensor",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
],
|
||||
)
|
||||
|
||||
@ -542,13 +543,17 @@ tf_cuda_library(
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
":tf_tensor",
|
||||
"//tensorflow/stream_executor:stream",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_internal",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
@ -7,10 +7,21 @@ load(
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "headers",
|
||||
srcs = [
|
||||
"stream_executor.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_executor_hdrs",
|
||||
hdrs = ["stream_executor.h"],
|
||||
@ -49,9 +60,11 @@ cc_library(
|
||||
"stream_executor.h",
|
||||
"stream_executor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/stream_executor:executor_cache",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -44,6 +43,7 @@ using tensorflow::StatusFromTF_Status;
|
||||
|
||||
namespace stream_executor {
|
||||
using tensorflow::StringPiece;
|
||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
namespace {
|
||||
|
||||
@ -188,41 +188,6 @@ port::Status ValidateSEPlatformRegistrationParams(
|
||||
}
|
||||
#undef VALIDATE_MEMBER
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
};
|
||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
class CStream : public internal::StreamInterface {
|
||||
public:
|
||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
stream_handle_(nullptr) {}
|
||||
~CStream() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||
port::Status s = StatusFromTF_Status(c_status.get());
|
||||
return s;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (stream_handle_ != nullptr) {
|
||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||
stream_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Stream Handle() { return stream_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Stream stream_handle_;
|
||||
};
|
||||
|
||||
// Converts SE_EventStatus to Event::Status.
|
||||
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||
switch (s) {
|
||||
@ -237,82 +202,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||
}
|
||||
}
|
||||
|
||||
class CEvent : public internal::EventInterface {
|
||||
public:
|
||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
event_handle_(nullptr) {}
|
||||
~CEvent() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
port::Status Record(SP_Stream stream_handle) {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||
c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (event_handle_ != nullptr) {
|
||||
stream_executor_->destroy_event(device_, event_handle_);
|
||||
event_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Event Handle() { return event_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Event event_handle_;
|
||||
};
|
||||
|
||||
class CTimer : public internal::TimerInterface {
|
||||
public:
|
||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_handle_(nullptr),
|
||||
timer_fns_(timer_fns) {}
|
||||
~CTimer() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (timer_handle_ != nullptr) {
|
||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||
timer_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Timer Handle() { return timer_handle_; }
|
||||
|
||||
uint64 Microseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||
}
|
||||
|
||||
uint64 Nanoseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_);
|
||||
}
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Timer timer_handle_;
|
||||
SP_TimerFns* timer_fns_;
|
||||
};
|
||||
|
||||
// Converts DeviceMemoryBase to a C struct.
|
||||
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
|
||||
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/stream_executor/executor_cache.h"
|
||||
#include "tensorflow/stream_executor/lib/status.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
@ -37,6 +38,13 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
|
||||
// testing).
|
||||
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
};
|
||||
|
||||
// This file implements core stream executor base classes in terms of
|
||||
// the C API defined in stream_executor.h. A class "CSomething" represents a
|
||||
// "Something" that can be manipulated via calls in the C interface.
|
||||
class CPlatform : public Platform {
|
||||
public:
|
||||
explicit CPlatform(SP_Platform platform,
|
||||
@ -83,5 +91,111 @@ class CPlatform : public Platform {
|
||||
stream_executor::ExecutorCache executor_cache_;
|
||||
};
|
||||
|
||||
class CStream : public internal::StreamInterface {
|
||||
public:
|
||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
stream_handle_(nullptr) {}
|
||||
~CStream() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||
port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
|
||||
return s;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (stream_handle_ != nullptr) {
|
||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||
stream_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Stream Handle() { return stream_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Stream stream_handle_;
|
||||
};
|
||||
|
||||
class CEvent : public internal::EventInterface {
|
||||
public:
|
||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
event_handle_(nullptr) {}
|
||||
~CEvent() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
port::Status Record(SP_Stream stream_handle) {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||
c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (event_handle_ != nullptr) {
|
||||
stream_executor_->destroy_event(device_, event_handle_);
|
||||
event_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Event Handle() { return event_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Event event_handle_;
|
||||
};
|
||||
|
||||
class CTimer : public internal::TimerInterface {
|
||||
public:
|
||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_handle_(nullptr),
|
||||
timer_fns_(timer_fns) {}
|
||||
~CTimer() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (timer_handle_ != nullptr) {
|
||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||
timer_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Timer Handle() { return timer_handle_; }
|
||||
|
||||
uint64 Microseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||
}
|
||||
|
||||
uint64 Nanoseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_);
|
||||
}
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Timer timer_handle_;
|
||||
SP_TimerFns* timer_fns_;
|
||||
};
|
||||
|
||||
} // namespace stream_executor
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
|
||||
|
@ -24,7 +24,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
// Required for IS_MOBILE_PLATFORM definition
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/stream_executor/stream.h"
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
|
||||
// This file forms the basis of a stable ABI for third-party kernel
|
||||
// implementations. It is crucial that changes to this file are made cautiously
|
||||
@ -168,6 +174,35 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
// This function is only for pluggable device.
|
||||
// It will return nullptr in all other cases.
|
||||
// This function is experimental and subject to change.
|
||||
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"Accessing device stream is not supported on mobile. File a bug at "
|
||||
"https://github.com/tensorflow/tensorflow/issues if this feature is "
|
||||
"important to you");
|
||||
return nullptr;
|
||||
#else
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
if (cc_ctx->op_device_context() == nullptr) { // CPU Device
|
||||
status->status = tensorflow::errors::FailedPrecondition(
|
||||
"Accessing device stream is not supported for a CPU device.");
|
||||
return nullptr;
|
||||
} else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
|
||||
status->status = tensorflow::errors::FailedPrecondition(
|
||||
"Accessing device stream is only supported for pluggable devices.");
|
||||
return nullptr;
|
||||
} else { // Is a PluggableDevice
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
auto c_stream = static_cast<stream_executor::CStream*>(
|
||||
cc_ctx->op_device_context()->stream()->implementation());
|
||||
return c_stream->Handle();
|
||||
}
|
||||
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
}
|
||||
|
||||
int TF_NumInputs(TF_OpKernelContext* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
return cc_ctx->num_inputs();
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
@ -65,6 +66,11 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
|
||||
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
|
||||
typedef struct TF_OpKernelContext TF_OpKernelContext;
|
||||
|
||||
// TF_InitKernel performs op/kernel registration.
|
||||
// Plugin should implement TF_InitKernel to register kernels. This function
|
||||
// should register all kernels in a plugin.
|
||||
void TF_InitKernel();
|
||||
|
||||
// Allocates a new kernel builder and returns a pointer to it.
|
||||
//
|
||||
// If non-null, TensorFlow will call create_func when it needs to instantiate
|
||||
@ -128,6 +134,16 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
|
||||
// --------------------------------------------------------------------------
|
||||
// OpKernelContext routines
|
||||
|
||||
// TF_GetStream returns the SP_Stream available in ctx.
|
||||
// This function returns a stream only for devices registered using the
|
||||
// StreamExecutor C API
|
||||
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
|
||||
// nullptr and set error status in all other cases.
|
||||
// Experimental: this function doesn't have compatibility guarantees and is
|
||||
// subject to change at any time.
|
||||
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// TF_NumInputs returns the number of inputs available in ctx.
|
||||
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
|
||||
|
||||
|
@ -378,6 +378,23 @@ template <typename T>
|
||||
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||
TF_OpKernelContext* ctx);
|
||||
|
||||
REGISTER_OP("StreamOp").Output("output1: float");
|
||||
|
||||
// Runs a kernel whose compute function calls TF_GetStream and expects a null
// stream together with a non-OK status.
TEST_F(DeviceKernelOpTest, TestStream) {
  auto compute_fn = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* status = TF_NewStatus();
    SP_Stream stream = TF_GetStream(ctx, status);
    // Stream is always null if device is not a pluggable device. More test
    // cases will be added when pluggable device mechanism is supported.
    EXPECT_EQ(stream, nullptr);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  };

  SetupOp("StreamOp", "StreamOp", compute_fn);
  TF_ASSERT_OK(RunOpKernel());
}
|
||||
|
||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||
|
@ -869,6 +869,9 @@ filegroup(
|
||||
srcs = [
|
||||
# Sources for which we do not yet have granular targets.
|
||||
"//tensorflow/c/eager:srcs",
|
||||
# StreamExecutor C API is currently not supported for mobile.
|
||||
# Including just the header for SP_Stream reference in kernels C API.
|
||||
"//tensorflow/c/experimental/stream_executor:headers",
|
||||
"//tensorflow/c:srcs",
|
||||
"//tensorflow/core/common_runtime:mobile_srcs_only_runtime",
|
||||
"//tensorflow/core/common_runtime/eager:srcs",
|
||||
|
@ -114,6 +114,9 @@ class DeviceContext : public core::RefCounted {
|
||||
std::function<void()> func) {
|
||||
return errors::Internal("ThenExecute not supported by device");
|
||||
}
|
||||
|
||||
// Returns true if the device is a pluggable device.
|
||||
virtual bool IsPluggableDevice() { return false; }
|
||||
};
|
||||
|
||||
class DeviceBase {
|
||||
@ -173,7 +176,6 @@ class DeviceBase {
|
||||
// Does not take ownership.
|
||||
void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
|
||||
|
||||
|
||||
// Return the Allocator implementation to use based on the allocator
|
||||
// attributes requested. See allocator.h for more details.
|
||||
virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
|
||||
@ -204,7 +206,6 @@ class DeviceBase {
|
||||
|
||||
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
|
||||
|
||||
|
||||
// Caller owns the return value. The OpKernelContext calls this even
|
||||
// for devices that do not implement an eigen_gpu_device. Overridden
|
||||
// by GPU devices to return a derived type.
|
||||
|
Loading…
Reference in New Issue
Block a user