Merge pull request #44747 from Intel-tensorflow:stream_executor_C_API_extension

PiperOrigin-RevId: 344156372
Change-Id: I036a76c7ccf5f25e45767b3bf3270cbed9a56830
This commit is contained in:
TensorFlower Gardener 2020-11-24 17:04:44 -08:00
commit aa22defe05
9 changed files with 207 additions and 114 deletions

View File

@ -522,6 +522,7 @@ cc_library(
":tf_datatype",
":tf_status",
":tf_tensor",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
],
)
@ -542,13 +543,17 @@ tf_cuda_library(
] + select({
"//tensorflow:android": [
":c_api_internal",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
":tf_tensor",
"//tensorflow/stream_executor:stream",
"//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
"//tensorflow/c/experimental/stream_executor:stream_executor",
"//tensorflow/c/experimental/stream_executor:stream_executor_internal",
],
}),
)

View File

@ -7,10 +7,21 @@ load(
"tf_cc_test",
)
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")
package(
licenses = ["notice"], # Apache 2.0
)
# Header-only target exposing stream_executor.h by itself, so packages that
# only need the SP_* struct declarations (e.g. the kernels C API on mobile,
# where the full StreamExecutor C API is not built) can depend on just the
# header instead of the implementation.
filegroup(
name = "headers",
srcs = [
"stream_executor.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "stream_executor_hdrs",
hdrs = ["stream_executor.h"],
@ -49,9 +60,11 @@ cc_library(
"stream_executor.h",
"stream_executor_internal.h",
],
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/stream_executor:executor_cache",
"//tensorflow/stream_executor/lib",
],

View File

@ -24,7 +24,6 @@ limitations under the License.
#include <string>
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
@ -44,6 +43,7 @@ using tensorflow::StatusFromTF_Status;
namespace stream_executor {
using tensorflow::StringPiece;
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
namespace {
@ -188,41 +188,6 @@ port::Status ValidateSEPlatformRegistrationParams(
}
#undef VALIDATE_MEMBER
// Custom deleter so a TF_Status* (a C-API object) can be owned by a
// std::unique_ptr and released via TF_DeleteStatus instead of `delete`.
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
// RAII handle for a TF_Status created with TF_NewStatus.
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
// Adapts a plugin-provided SP_Stream (StreamExecutor C API) to the
// StreamExecutor C++ StreamInterface. Creation/destruction of the underlying
// stream is delegated to the plugin's SP_StreamExecutor callbacks.
class CStream : public internal::StreamInterface {
public:
// `device` and `stream_executor` are stored as raw pointers and are not
// owned here; presumably they must outlive this object — TODO confirm.
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
// RAII: releases the plugin stream handle, if any, on destruction.
~CStream() override { Destroy(); }
// Asks the plugin to create the underlying stream. Returns the status
// reported by the plugin through the TF_Status out-parameter.
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
return s;
}
// Destroys the plugin stream. Safe to call multiple times: the handle is
// nulled after destruction so a second call is a no-op.
void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
// Raw plugin handle; nullptr until Create() succeeds or after Destroy().
SP_Stream Handle() { return stream_handle_; }
private:
SP_Device* device_;                  // not owned
SP_StreamExecutor* stream_executor_; // not owned
SP_Stream stream_handle_;            // owned; released in Destroy()
};
// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
switch (s) {
@ -237,82 +202,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
}
}
// Adapts a plugin-provided SP_Event (StreamExecutor C API) to the
// StreamExecutor C++ EventInterface. All operations forward to the plugin's
// SP_StreamExecutor callbacks.
class CEvent : public internal::EventInterface {
public:
// `device` and `stream_executor` are not owned; presumably they must
// outlive this object — TODO confirm.
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
// RAII: releases the plugin event handle, if any, on destruction.
~CEvent() override { Destroy(); }
// Asks the plugin to create the underlying event.
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
// Records this event on the given plugin stream.
port::Status Record(SP_Stream stream_handle) {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->record_event(device_, stream_handle, event_handle_,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
// Destroys the plugin event. Idempotent: the handle is nulled afterwards.
void Destroy() {
if (event_handle_ != nullptr) {
stream_executor_->destroy_event(device_, event_handle_);
event_handle_ = nullptr;
}
}
// Raw plugin handle; nullptr until Create() succeeds or after Destroy().
SP_Event Handle() { return event_handle_; }
private:
SP_Device* device_;                  // not owned
SP_StreamExecutor* stream_executor_; // not owned
SP_Event event_handle_;              // owned; released in Destroy()
};
// Adapts a plugin-provided SP_Timer (StreamExecutor C API) to the
// StreamExecutor C++ TimerInterface. Lifetime is managed through the
// plugin's SP_StreamExecutor callbacks; readings come from SP_TimerFns.
class CTimer : public internal::TimerInterface {
public:
// None of the pointer arguments are owned; presumably they must outlive
// this object — TODO confirm.
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
SP_TimerFns* timer_fns)
: device_(device),
stream_executor_(stream_executor),
timer_handle_(nullptr),
timer_fns_(timer_fns) {}
// RAII: releases the plugin timer handle, if any, on destruction.
~CTimer() override { Destroy(); }
// Asks the plugin to create the underlying timer.
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
// Destroys the plugin timer. Idempotent: the handle is nulled afterwards.
void Destroy() {
if (timer_handle_ != nullptr) {
stream_executor_->destroy_timer(device_, timer_handle_);
timer_handle_ = nullptr;
}
}
// Raw plugin handle; nullptr until Create() succeeds or after Destroy().
SP_Timer Handle() { return timer_handle_; }
// Derived from the plugin's nanosecond reading (integer division truncates).
uint64 Microseconds() const override {
return timer_fns_->nanoseconds(timer_handle_) / 1000;
}
uint64 Nanoseconds() const override {
return timer_fns_->nanoseconds(timer_handle_);
}
private:
SP_Device* device_;                  // not owned
SP_StreamExecutor* stream_executor_; // not owned
SP_Timer timer_handle_;              // owned; released in Destroy()
SP_TimerFns* timer_fns_;             // not owned; supplies timer readings
};
// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};

View File

@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"
@ -37,6 +38,13 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
// testing).
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
// Custom deleter so a TF_Status* (a C-API object) can be owned by a
// std::unique_ptr and released via TF_DeleteStatus instead of `delete`.
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
// This file implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface.
class CPlatform : public Platform {
public:
explicit CPlatform(SP_Platform platform,
@ -83,5 +91,111 @@ class CPlatform : public Platform {
stream_executor::ExecutorCache executor_cache_;
};
// Adapts a plugin-provided SP_Stream (StreamExecutor C API) to the
// StreamExecutor C++ StreamInterface. Creation/destruction of the underlying
// stream is delegated to the plugin's SP_StreamExecutor callbacks.
class CStream : public internal::StreamInterface {
public:
// `device` and `stream_executor` are stored as raw pointers and are not
// owned here; presumably they must outlive this object — TODO confirm.
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
// RAII: releases the plugin stream handle, if any, on destruction.
~CStream() override { Destroy(); }
// Asks the plugin to create the underlying stream. Returns the status
// reported by the plugin through the TF_Status out-parameter.
port::Status Create() {
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
return s;
}
// Destroys the plugin stream. Safe to call multiple times: the handle is
// nulled after destruction so a second call is a no-op.
void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
// Raw plugin handle; nullptr until Create() succeeds or after Destroy().
SP_Stream Handle() { return stream_handle_; }
private:
SP_Device* device_;                  // not owned
SP_StreamExecutor* stream_executor_; // not owned
SP_Stream stream_handle_;            // owned; released in Destroy()
};
// Adapts a plugin-provided SP_Event (StreamExecutor C API) to the
// StreamExecutor C++ EventInterface. All operations forward to the plugin's
// SP_StreamExecutor callbacks.
class CEvent : public internal::EventInterface {
public:
// `device` and `stream_executor` are not owned; presumably they must
// outlive this object — TODO confirm.
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
// RAII: releases the plugin event handle, if any, on destruction.
~CEvent() override { Destroy(); }
// Asks the plugin to create the underlying event.
port::Status Create() {
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
return tensorflow::StatusFromTF_Status(c_status.get());
}
// Records this event on the given plugin stream.
port::Status Record(SP_Stream stream_handle) {
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
stream_executor_->record_event(device_, stream_handle, event_handle_,
c_status.get());
return tensorflow::StatusFromTF_Status(c_status.get());
}
// Destroys the plugin event. Idempotent: the handle is nulled afterwards.
void Destroy() {
if (event_handle_ != nullptr) {
stream_executor_->destroy_event(device_, event_handle_);
event_handle_ = nullptr;
}
}
// Raw plugin handle; nullptr until Create() succeeds or after Destroy().
SP_Event Handle() { return event_handle_; }
private:
SP_Device* device_;                  // not owned
SP_StreamExecutor* stream_executor_; // not owned
SP_Event event_handle_;              // owned; released in Destroy()
};
// Adapts a plugin-provided SP_Timer (StreamExecutor C API) to the
// StreamExecutor C++ TimerInterface. Lifetime is managed through the
// plugin's SP_StreamExecutor callbacks; readings come from SP_TimerFns.
class CTimer : public internal::TimerInterface {
public:
// None of the pointer arguments are owned; presumably they must outlive
// this object — TODO confirm.
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
SP_TimerFns* timer_fns)
: device_(device),
stream_executor_(stream_executor),
timer_handle_(nullptr),
timer_fns_(timer_fns) {}
// RAII: releases the plugin timer handle, if any, on destruction.
~CTimer() override { Destroy(); }
// Asks the plugin to create the underlying timer.
port::Status Create() {
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
return tensorflow::StatusFromTF_Status(c_status.get());
}
// Destroys the plugin timer. Idempotent: the handle is nulled afterwards.
void Destroy() {
if (timer_handle_ != nullptr) {
stream_executor_->destroy_timer(device_, timer_handle_);
timer_handle_ = nullptr;
}
}
// Raw plugin handle; nullptr until Create() succeeds or after Destroy().
SP_Timer Handle() { return timer_handle_; }
// Derived from the plugin's nanosecond reading (integer division truncates).
uint64 Microseconds() const override {
return timer_fns_->nanoseconds(timer_handle_) / 1000;
}
uint64 Nanoseconds() const override {
return timer_fns_->nanoseconds(timer_handle_);
}
private:
SP_Device* device_;                  // not owned
SP_StreamExecutor* stream_executor_; // not owned
SP_Timer timer_handle_;              // owned; released in Destroy()
SP_TimerFns* timer_fns_;             // not owned; supplies timer readings
};
} // namespace stream_executor
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

View File

@ -24,7 +24,13 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
// Required for IS_MOBILE_PLATFORM definition
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
@ -168,6 +174,35 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
TF_SetStatus(status, TF_OK, "");
}
// Returns the SP_Stream associated with `ctx`'s device. Only meaningful for
// pluggable devices registered through the StreamExecutor C API; in every
// other case it returns nullptr and sets an error in `status`.
// This function is experimental and subject to change.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
// The StreamExecutor C API is not compiled into mobile/slim builds.
status->status = tensorflow::errors::Unimplemented(
"Accessing device stream is not supported on mobile. File a bug at "
"https://github.com/tensorflow/tensorflow/issues if this feature is "
"important to you");
return nullptr;
#else
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
// A null op_device_context() indicates a CPU device, which has no stream.
if (cc_ctx->op_device_context() == nullptr) {  // CPU Device
status->status = tensorflow::errors::FailedPrecondition(
"Accessing device stream is not supported for a CPU device.");
return nullptr;
} else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
// Non-pluggable (e.g. built-in GPU) devices don't expose an SP_Stream.
status->status = tensorflow::errors::FailedPrecondition(
"Accessing device stream is only supported for pluggable devices.");
return nullptr;
} else {  // Is a PluggableDevice
TF_SetStatus(status, TF_OK, "");
// For pluggable devices the stream implementation is a CStream wrapping
// the plugin's SP_Stream; hand the raw plugin handle back to the caller.
auto c_stream = static_cast<stream_executor::CStream*>(
cc_ctx->op_device_context()->stream()->implementation());
return c_stream->Handle();
}
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
int TF_NumInputs(TF_OpKernelContext* ctx) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
return cc_ctx->num_inputs();

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
@ -65,6 +66,11 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
typedef struct TF_OpKernelContext TF_OpKernelContext;
// TF_InitKernel to do op/kernel registration.
// Plugin should implement TF_InitKernel to register kernels. This function
// should register all kernels in a plugin.
void TF_InitKernel();
// Allocates a new kernel builder and returns a pointer to it.
//
// If non-null, TensorFlow will call create_func when it needs to instantiate
@ -128,6 +134,16 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
// --------------------------------------------------------------------------
// OpKernelContext routines
// TF_GetStream returns the SP_Stream available in ctx.
// This function returns a stream only for devices registered using the
// StreamExecutor C API
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
// nullptr and set error status in all other cases.
// Experimental: this function doesn't have compatibility guarantees and subject
// to change at any time.
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx,
TF_Status* status);
// TF_NumInputs returns the number of inputs available in ctx.
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);

View File

@ -378,6 +378,23 @@ template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx);
// Minimal op used solely to exercise TF_GetStream from a kernel compute fn.
REGISTER_OP("StreamOp").Output("output1: float");
// Verifies that TF_GetStream fails (nullptr result, non-OK status) when the
// kernel does not run on a pluggable device.
TEST_F(DeviceKernelOpTest, TestStream) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
SP_Stream stream = TF_GetStream(ctx, s);
// Stream is always null if device is not a pluggable device. More test
// cases will be added when pluggable device mechanism is supported.
EXPECT_EQ(stream, nullptr);
EXPECT_NE(TF_OK, TF_GetCode(s));
TF_DeleteStatus(s);
};
SetupOp("StreamOp", "StreamOp", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
}
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {

View File

@ -869,6 +869,9 @@ filegroup(
srcs = [
# Sources for which we do not yet have granular targets.
"//tensorflow/c/eager:srcs",
# StreamExecutor C API is currently not supported for mobile.
# Including just the header for SP_Stream reference in kernels C API.
"//tensorflow/c/experimental/stream_executor:headers",
"//tensorflow/c:srcs",
"//tensorflow/core/common_runtime:mobile_srcs_only_runtime",
"//tensorflow/core/common_runtime/eager:srcs",

View File

@ -114,6 +114,9 @@ class DeviceContext : public core::RefCounted {
std::function<void()> func) {
return errors::Internal("ThenExecute not supported by device");
}
// check if device is a pluggable device
virtual bool IsPluggableDevice() { return false; }
};
class DeviceBase {
@ -173,7 +176,6 @@ class DeviceBase {
// Does not take ownership.
void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
// Return the Allocator implementation to use based on the allocator
// attributes requested. See allocator.h for more details.
virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
@ -204,7 +206,6 @@ class DeviceBase {
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
// Caller owns the return value. The OpKernelContext calls this even
// for devices that do not implement an eigen_gpu_device. Overridden
// by GPU devices to return a derived type.