[PluggableDevice] add load mechanism and refactor streamexecutor C API
parent 16a59707b6
commit f599aed5d8
@@ -145,6 +145,8 @@ if _running_from_pip_package():
    _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
    if _os.path.exists(_plugin_dir):
      _ll.load_library(_plugin_dir)
      # Load Pluggable Device Library
      _ll.load_pluggable_device_library(_plugin_dir)

# Add module aliases
if hasattr(_current_module, 'keras'):
@@ -155,6 +155,8 @@ if _running_from_pip_package():
    _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
    if _os.path.exists(_plugin_dir):
      _ll.load_library(_plugin_dir)
      # Load Pluggable Device Library
      _ll.load_pluggable_device_library(_plugin_dir)

# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
@@ -492,6 +492,8 @@ tf_cuda_library(
    ],
    hdrs = [
        "kernels.h",
        "//tensorflow/c/experimental/stream_executor:stream_executor.h",
        "//tensorflow/c/experimental/stream_executor:stream_executor_internal.h",
    ],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
@@ -499,6 +501,7 @@ tf_cuda_library(
        ":tf_status",
        ":tf_status_helper",
        ":tf_tensor_internal",
        "//tensorflow/c/experimental/stream_executor:stream_executor",
    ] + select({
        "//tensorflow:android": [
            ":c_api_internal",
@@ -578,6 +581,7 @@ tf_cuda_cc_test(
    srcs = ["c_api_test.cc"],
    data = [
        ":test_op1.so",
        ":test_pluggable_device.so",
        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
    ],
    linkopts = select({
@@ -688,6 +692,19 @@ tf_custom_op_library(
    srcs = ["test_op1.cc"],
)

cc_library(
    name = "test_pluggable_device_impl",
    srcs = ["test_pluggable_device.cc"],
    hdrs = [
        "c_api_macros.h",
        "tf_status.h",
        "//tensorflow/c/experimental/stream_executor:stream_executor.h",
    ],
)

tf_custom_op_library(
    name = "test_pluggable_device.so",
    deps = [":test_pluggable_device_impl"],
)

tf_kernel_library(
    name = "test_op_kernel",
    srcs = ["test_op.cc"],
@@ -308,6 +308,9 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
Status LoadDynamicLibrary(const char* library_filename, void** result,
                          const void** buf, size_t* len);

// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);

// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
@@ -581,6 +584,23 @@ TF_Buffer* TF_GetAllOpList() {
  return ret;
}

TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
  TF_Library* lib_handle = new TF_Library;
  status->status = tensorflow::LoadPluggableDeviceLibrary(
      library_filename, &lib_handle->lib_handle);
  if (!status->status.ok()) {
    delete lib_handle;
    return nullptr;
  }
  return lib_handle;
}

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle == nullptr) return;
  delete lib_handle;
}

// --------------------------------------------------------------------------
// ListDevices & SessionListDevices API
@@ -1449,7 +1449,6 @@ typedef struct TF_Library TF_Library;
// On failure, place an error status in status and return NULL.
TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename,
                                                 TF_Status* status);

// Get the OpList of OpDefs defined in the library pointed by lib_handle.
//
// Returns a TF_Buffer. The memory pointed to by the result is owned by
@@ -1461,6 +1460,24 @@ TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);

// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, place OK in status and return the newly created library handle.
// The caller owns the library handle.
//
// On failure, place an error status in status and return NULL.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
    const char* library_filename, TF_Status* status);

// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
    TF_Library* lib_handle);

// Get the OpList of all OpDefs defined in this address space.
// Returns a TF_Buffer, ownership of which is transferred to the caller
// (and can be freed using TF_DeleteBuffer).
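For orientation, a minimal usage sketch of the two entry points above follows; the plugin path and the error handling are illustrative assumptions, not code from this commit:

#include <stdio.h>

#include "tensorflow/c/c_api.h"

void LoadMyPluggableDevice(void) {
  TF_Status* status = TF_NewStatus();
  // "my_pluggable_device.so" is a hypothetical plugin path.
  TF_Library* lib =
      TF_LoadPluggableDeviceLibrary("my_pluggable_device.so", status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "Plugin load failed: %s\n", TF_Message(status));
  }
  TF_DeleteStatus(status);
  // Frees only the handle; the library itself stays loaded. Safe on nullptr.
  TF_DeletePluggableDeviceLibraryHandle(lib);
}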
@@ -228,6 +228,21 @@ TEST(CAPI, LibraryLoadFunctions) {
  }
}

TEST(CAPI, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  // Load the library.
  TF_Status* status = TF_NewStatus();
  string lib_path = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "c", "test_pluggable_device.so"));
  TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
  TF_Code code = TF_GetCode(status);
  string status_msg(TF_Message(status));
  TF_DeleteStatus(status);
  ASSERT_EQ(TF_OK, code) << status_msg;
  TF_DeletePluggableDeviceLibraryHandle(lib);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}

void TestEncodeDecode(int line, const std::vector<string>& data) {
  const tensorflow::int64 n = data.size();
  Status status;
@@ -1467,6 +1482,7 @@ TEST(CAPI, DeletingNullPointerIsSafe) {
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteDeviceList(nullptr);
  TF_DeleteLibraryHandle(nullptr);
  TF_DeletePluggableDeviceLibraryHandle(nullptr);
  TF_DeleteApiDefMap(nullptr);

  TF_DeleteStatus(status);
@@ -59,3 +59,12 @@ tf_cc_test(
        "//tensorflow/stream_executor:stream_executor_pimpl",
    ],
)

exports_files(
    srcs = [
        "stream_executor.h",
        "stream_executor_internal.h",
    ],
    visibility = ["//visibility:public"],
)
@@ -24,7 +24,6 @@ limitations under the License.
#include <string>

#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
@@ -158,41 +157,6 @@ port::Status ValidateSEPlatformRegistrationParams(
}
#undef VALIDATE_MEMBER

struct TFStatusDeleter {
  void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;

class CStream : public internal::StreamInterface {
 public:
  CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
      : device_(device),
        stream_executor_(stream_executor),
        stream_handle_(nullptr) {}
  ~CStream() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    return s;
  }

  void Destroy() {
    if (stream_handle_ != nullptr) {
      stream_executor_->destroy_stream(device_, stream_handle_);
      stream_handle_ = nullptr;
    }
  }

  SP_Stream Handle() { return stream_handle_; }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Stream stream_handle_;
};

// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
  switch (s) {
@@ -207,82 +171,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
  }
}

class CEvent : public internal::EventInterface {
 public:
  CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
      : device_(device),
        stream_executor_(stream_executor),
        event_handle_(nullptr) {}
  ~CEvent() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_event(device_, &event_handle_, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  port::Status Record(SP_Stream stream_handle) {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->record_event(device_, stream_handle, event_handle_,
                                   c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  void Destroy() {
    if (event_handle_ != nullptr) {
      stream_executor_->destroy_event(device_, event_handle_);
      event_handle_ = nullptr;
    }
  }

  SP_Event Handle() { return event_handle_; }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Event event_handle_;
};

class CTimer : public internal::TimerInterface {
 public:
  CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
         SP_TimerFns* timer_fns)
      : device_(device),
        stream_executor_(stream_executor),
        timer_handle_(nullptr),
        timer_fns_(timer_fns) {}
  ~CTimer() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  void Destroy() {
    if (timer_handle_ != nullptr) {
      stream_executor_->destroy_timer(device_, timer_handle_);
      timer_handle_ = nullptr;
    }
  }

  SP_Timer Handle() { return timer_handle_; }

  uint64 Microseconds() const override {
    return timer_fns_->nanoseconds(timer_handle_) / 1000;
  }

  uint64 Nanoseconds() const override {
    return timer_fns_->nanoseconds(timer_handle_);
  }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Timer timer_handle_;
  SP_TimerFns* timer_fns_;
};

// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
  SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
@@ -19,10 +19,12 @@ limitations under the License.
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"

using tensorflow::StatusFromTF_Status;
namespace stream_executor {

// Plugin initialization function that a device plugin
@@ -37,6 +39,13 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
// testing).
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);

namespace {
struct TFStatusDeleter {
  void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
}  // namespace

class CPlatform : public Platform {
 public:
  explicit CPlatform(SP_Platform platform,
@@ -83,5 +92,111 @@ class CPlatform : public Platform {
  stream_executor::ExecutorCache executor_cache_;
};

class CStream : public internal::StreamInterface {
 public:
  CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
      : device_(device),
        stream_executor_(stream_executor),
        stream_handle_(nullptr) {}
  ~CStream() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    return s;
  }

  void Destroy() {
    if (stream_handle_ != nullptr) {
      stream_executor_->destroy_stream(device_, stream_handle_);
      stream_handle_ = nullptr;
    }
  }

  SP_Stream Handle() { return stream_handle_; }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Stream stream_handle_;
};

class CEvent : public internal::EventInterface {
 public:
  CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
      : device_(device),
        stream_executor_(stream_executor),
        event_handle_(nullptr) {}
  ~CEvent() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_event(device_, &event_handle_, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  port::Status Record(SP_Stream stream_handle) {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->record_event(device_, stream_handle, event_handle_,
                                   c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  void Destroy() {
    if (event_handle_ != nullptr) {
      stream_executor_->destroy_event(device_, event_handle_);
      event_handle_ = nullptr;
    }
  }

  SP_Event Handle() { return event_handle_; }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Event event_handle_;
};

class CTimer : public internal::TimerInterface {
 public:
  CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
         SP_TimerFns* timer_fns)
      : device_(device),
        stream_executor_(stream_executor),
        timer_handle_(nullptr),
        timer_fns_(timer_fns) {}
  ~CTimer() override { Destroy(); }

  port::Status Create() {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  void Destroy() {
    if (timer_handle_ != nullptr) {
      stream_executor_->destroy_timer(device_, timer_handle_);
      timer_handle_ = nullptr;
    }
  }

  SP_Timer Handle() { return timer_handle_; }

  uint64 Microseconds() const override {
    return timer_fns_->nanoseconds(timer_handle_) / 1000;
  }

  uint64 Nanoseconds() const override {
    return timer_fns_->nanoseconds(timer_handle_);
  }

 private:
  SP_Device* device_;
  SP_StreamExecutor* stream_executor_;
  SP_Timer timer_handle_;
  SP_TimerFns* timer_fns_;
};

}  // namespace stream_executor
#endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -25,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream.h"

// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
@@ -168,6 +170,15 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
  TF_SetStatus(status, TF_OK, "");
}

// This function is only for pluggable devices;
// it returns nullptr in all other cases.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  auto c_stream = static_cast<stream_executor::CStream*>(
      cc_ctx->op_device_context()->stream()->implementation());
  return c_stream->Handle();
}

int TF_NumInputs(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return cc_ctx->num_inputs();
@@ -19,6 +19,7 @@ limitations under the License.
#include <stdint.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
@@ -65,6 +66,10 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
typedef struct TF_OpKernelContext TF_OpKernelContext;

// TF_InitKernel is called to do op/kernel registration.
// A plugin needs to implement this function to register all its kernels.
void TF_InitKernel();

// Allocates a new kernel builder and returns a pointer to it.
//
// If non-null, TensorFlow will call create_func when it needs to instantiate
@@ -128,6 +133,11 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
// --------------------------------------------------------------------------
// OpKernelContext routines

// TF_GetStream returns the SP_Stream available in ctx.
// This function is only for pluggable devices; it returns
// nullptr in all other cases.
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);

// TF_NumInputs returns the number of inputs available in ctx.
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
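To make the intended use of TF_GetStream concrete, here is a hedged sketch of a plugin kernel's compute callback. The kernel name and the work enqueued on the stream are assumptions; only the kernel C API calls themselves come from this header:

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_tensor.h"

// Hypothetical compute callback for a kernel registered by a plugin.
static void MyPluginKernel_Compute(void* kernel, TF_OpKernelContext* ctx) {
  (void)kernel;  // unused in this sketch
  // Only meaningful when the kernel is running on a pluggable device.
  SP_Stream stream = TF_GetStream(ctx);
  TF_Status* status = TF_NewStatus();
  TF_Tensor* input = nullptr;
  TF_GetInput(ctx, 0, &input, status);
  if (TF_GetCode(status) == TF_OK && stream != nullptr) {
    // The plugin would enqueue its device work for `input` on `stream` here.
    TF_SetOutput(ctx, 0, input, status);
  }
  TF_DeleteTensor(input);
  TF_DeleteStatus(status);
}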
tensorflow/c/test_pluggable_device.cc (new file, 24 lines)
@@ -0,0 +1,24 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
                   TF_Status* const status) {
  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  params->platform->name = "GPU";
  params->platform->type = "XGPU";
}
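The test plugin above fills in only the fields the load test needs; it is not a complete StreamExecutor plugin. As a hedged sketch, the same entry point could be exercised directly through the SEInitPluginFn overload declared in stream_executor_internal.h; a real plugin would also have to populate the device and stream-executor callbacks before that path's validation would pass:

#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"

// Assumes SE_InitPlugin from the test file above is visible in this
// translation unit; the wrapper function itself is hypothetical.
stream_executor::port::Status InitTestPlugin() {
  return stream_executor::InitStreamExecutorPlugin(&SE_InitPlugin);
}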
@@ -101,4 +101,35 @@ Status LoadDynamicLibrary(const char* library_filename, void** result,
  return Status::OK();
}

// Load a PluggableDevice library.
// On success, returns the handle to the library in `result` and returns OK
// from the function. Otherwise returns nullptr in `result` and an error
// Status from the function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Devices and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result) {
  static mutex mu(LINKER_INITIALIZED);
  static std::unordered_map<string, Library> loaded_libs;
  Env* env = Env::Default();
  Library library;
  {
    mutex_lock lock(mu);
    if (loaded_libs.find(library_filename) != loaded_libs.end()) {
      library = loaded_libs[library_filename];
    } else {
      Status s = env->LoadDynamicLibrary(library_filename, &library.handle);
      if (s.ok()) {
        // Init PluggableDevice Plugin
      } else {
        return s;
      }

      loaded_libs[library_filename] = library;
    }
    *result = library.handle;
    return Status::OK();
  }
}
}  // namespace tensorflow
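The "// Init PluggableDevice Plugin" placeholder above marks where plugin initialization is expected to hook in. A hedged sketch of what that branch might look like, assuming this file could call the StreamExecutor C API initializer declared in stream_executor_internal.h (a dependency this commit does not yet wire up) and that port::Status converts to tensorflow::Status:

Status s = env->LoadDynamicLibrary(library_filename, &library.handle);
if (s.ok()) {
  // Hand the freshly loaded .so to the StreamExecutor C API, which is expected
  // to look up SE_InitPlugin in the library and register the platform.
  s = stream_executor::InitStreamExecutorPlugin(library.handle);
  if (!s.ok()) return s;
} else {
  return s;
}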
@@ -5976,11 +5976,10 @@ py_library(
    name = "pywrap_tensorflow",
    srcs = [
        "pywrap_tensorflow.py",
    ] + if_static(
        ["pywrap_dlopen_global_flags.py"],
        # Import will fail, indicating no global dlopen flags
        otherwise = [],
    ),  # b/153585257
        # modular TF needs this file to expose C API symbols
        # in pywrap_tensorflow_internal.so
        "pywrap_dlopen_global_flags.py",
    ],
    srcs_version = "PY2AND3",
    deps = [":pywrap_tensorflow_internal"],
)
@@ -36,9 +36,9 @@ import traceback

# go/tf-wildcard-import
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow

from tensorflow.python.eager import context
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow

# pylint: enable=wildcard-import
@@ -711,6 +711,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
      },
      py::return_value_policy::reference);

  m.def(
      "TF_LoadPluggableDeviceLibrary",
      [](const char* library_filename) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output =
            TF_LoadPluggableDeviceLibrary(library_filename, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_GetOpList", [](TF_Library* lib_handle) {
    TF_Buffer output_buffer = TF_GetOpList(lib_handle);
    return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
@@ -720,6 +732,11 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {

  m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_DeletePluggableDeviceLibraryHandle",
        TF_DeletePluggableDeviceLibraryHandle,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_AddControlInput", TF_AddControlInput);
  m.def(
      "TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {
@@ -1268,6 +1268,29 @@ class Context(object):
    # Import device settings that may have been passed into the constructor
    self._import_config()

  def reinitialize_physical_devices(self):
    """Refreshes the list of physical devices visible to the system."""
    # We lazily initialize self._physical_devices since we do not want to do
    # this in the constructor, as the backend may not be initialized yet.
    with self._device_lock:
      devs = pywrap_tfe.TF_ListPhysicalDevices()
      self._physical_devices = [
          PhysicalDevice(name=d.decode(),
                         device_type=d.decode().split(":")[1]) for d in devs]
      self._physical_device_to_index = {
          p: i for i, p in enumerate(self._physical_devices)
      }

      self._visible_device_list = list(self._physical_devices)
      self._memory_growth_map = {
          d: None for d in self._physical_devices if d.device_type == "GPU"
      }

    # Import device settings that may have been passed into the constructor
    self._import_config()

  def list_physical_devices(self, device_type=None):
    """List local devices visible to the system.
@@ -29,7 +29,7 @@ from tensorflow.python import _pywrap_python_op_gen
from tensorflow.python.client import pywrap_tf_session as py_tf
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

from tensorflow.python.eager import context

@tf_export('load_op_library')
def load_op_library(library_filename):
@@ -157,3 +157,44 @@ def load_library(library_location):
        errno.ENOENT,
        'The file or folder to load kernel libraries from does not exist.',
        library_location)

@tf_export('load_pluggable_device_library')
def load_pluggable_device_library(library_location):
  """Loads a TensorFlow PluggableDevice plugin.

  "library_location" can be a path to a specific shared object, or a folder.
  If it is a folder, all shared objects will be loaded. When a library is
  loaded, devices and kernels registered in it via the StreamExecutor C API
  and the Kernel/Op Registration C API are made available in the TensorFlow
  process.

  Args:
    library_location: Path to the plugin or folder of plugins.
      Relative or absolute filesystem path to a dynamic library file or folder.

  Returns:
    None

  Raises:
    OSError: When the file to be loaded is not found.
    RuntimeError: When unable to load the library.
  """
  if file_io.file_exists(library_location):
    if file_io.is_directory(library_location):
      directory_contents = file_io.list_directory(library_location)

      pluggable_device_libraries = [
          os.path.join(library_location, f) for f in directory_contents
          if _is_shared_object(f)]
    else:
      pluggable_device_libraries = [library_location]

    for lib in pluggable_device_libraries:
      py_tf.TF_LoadPluggableDeviceLibrary(lib)
    # Reinitialize the physical device list after plugin registration.
    context.context().reinitialize_physical_devices()
  else:
    raise OSError(
        errno.ENOENT,
        'The file or folder to load pluggable device libraries from does not exist.',
        library_location)
@@ -1576,6 +1576,10 @@ tf_module {
    name: "load_op_library"
    argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "load_pluggable_device_library"
    argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "local_variables"
    argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -772,6 +772,10 @@ tf_module {
    name: "load_op_library"
    argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "load_pluggable_device_library"
    argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "logical_and"
    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "