[PluggableDevice] add load mechanism and refactor streamexecutor C API

This commit is contained in:
Zhoulong Jiang 2020-09-28 09:04:52 +00:00
parent 16a59707b6
commit f599aed5d8
20 changed files with 370 additions and 120 deletions

View File

@ -145,6 +145,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Add module aliases
if hasattr(_current_module, 'keras'):

View File

@ -155,6 +155,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.

View File

@ -492,6 +492,8 @@ tf_cuda_library(
],
hdrs = [
"kernels.h",
"//tensorflow/c/experimental/stream_executor:stream_executor.h",
"//tensorflow/c/experimental/stream_executor:stream_executor_internal.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
@ -499,6 +501,7 @@ tf_cuda_library(
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
"//tensorflow/c/experimental/stream_executor:stream_executor",
] + select({
"//tensorflow:android": [
":c_api_internal",
@ -578,6 +581,7 @@ tf_cuda_cc_test(
srcs = ["c_api_test.cc"],
data = [
":test_op1.so",
":test_pluggable_device.so",
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
linkopts = select({
@ -688,6 +692,19 @@ tf_custom_op_library(
srcs = ["test_op1.cc"],
)
cc_library(
name = "test_pluggable_device_impl",
srcs = ["test_pluggable_device.cc"],
hdrs = ["c_api_macros.h",
"tf_status.h",
"//tensorflow/c/experimental/stream_executor:stream_executor.h",],
)
tf_custom_op_library(
name = "test_pluggable_device.so",
deps = [":test_pluggable_device_impl"],
)
tf_kernel_library(
name = "test_op_kernel",
srcs = ["test_op.cc"],

View File

@ -308,6 +308,9 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
Status LoadDynamicLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
// Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file)
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
@ -581,6 +584,23 @@ TF_Buffer* TF_GetAllOpList() {
return ret;
}
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
TF_Status* status) {
TF_Library* lib_handle = new TF_Library;
status->status = tensorflow::LoadPluggableDeviceLibrary(
library_filename, &lib_handle->lib_handle);
if (!status->status.ok()) {
delete lib_handle;
return nullptr;
}
return lib_handle;
}
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
if (lib_handle == nullptr) return;
delete lib_handle;
}
// --------------------------------------------------------------------------
// ListDevices & SessionListDevices API

View File

@ -1449,7 +1449,6 @@ typedef struct TF_Library TF_Library;
// On failure, place an error status in status and return NULL.
TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename,
TF_Status* status);
// Get the OpList of OpDefs defined in the library pointed by lib_handle.
//
// Returns a TF_Buffer. The memory pointed to by the result is owned by
@ -1461,6 +1460,24 @@ TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The reles for determining the exact location of the
// library are platform-specific and are not documented here
//
// On success, place OK in status and return the newly created library handle.
// The caller owns the library handle
//
// On failure, place an error status in status and return NULL.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
const char* library_filename, TF_Status* status);
// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
TF_Library* lib_handle);
// Get the OpList of all OpDefs defined in this address space.
// Returns a TF_Buffer, ownership of which is transferred to the caller
// (and can be freed using TF_DeleteBuffer).

View File

@ -228,6 +228,21 @@ TEST(CAPI, LibraryLoadFunctions) {
}
}
TEST(CAPI, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
// Load the library.
TF_Status* status = TF_NewStatus();
string lib_path = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "c", "test_pluggable_device.so"));
TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
TF_DeletePluggableDeviceLibraryHandle(lib);
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}
void TestEncodeDecode(int line, const std::vector<string>& data) {
const tensorflow::int64 n = data.size();
Status status;
@ -1467,6 +1482,7 @@ TEST(CAPI, DeletingNullPointerIsSafe) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteDeviceList(nullptr);
TF_DeleteLibraryHandle(nullptr);
TF_DeletePluggableDeviceLibraryHandle(nullptr);
TF_DeleteApiDefMap(nullptr);
TF_DeleteStatus(status);

View File

@ -59,3 +59,12 @@ tf_cc_test(
"//tensorflow/stream_executor:stream_executor_pimpl",
],
)
exports_files(
srcs = [
"stream_executor.h",
"stream_executor_internal.h",
],
visibility = ["//visibility:public"],
)

View File

@ -24,7 +24,6 @@ limitations under the License.
#include <string>
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
@ -158,41 +157,6 @@ port::Status ValidateSEPlatformRegistrationParams(
}
#undef VALIDATE_MEMBER
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
class CStream : public internal::StreamInterface {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
~CStream() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
return s;
}
void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
SP_Stream Handle() { return stream_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Stream stream_handle_;
};
// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
switch (s) {
@ -207,82 +171,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
}
}
class CEvent : public internal::EventInterface {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status Record(SP_Stream stream_handle) {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->record_event(device_, stream_handle, event_handle_,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (event_handle_ != nullptr) {
stream_executor_->destroy_event(device_, event_handle_);
event_handle_ = nullptr;
}
}
SP_Event Handle() { return event_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Event event_handle_;
};
class CTimer : public internal::TimerInterface {
public:
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
SP_TimerFns* timer_fns)
: device_(device),
stream_executor_(stream_executor),
timer_handle_(nullptr),
timer_fns_(timer_fns) {}
~CTimer() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (timer_handle_ != nullptr) {
stream_executor_->destroy_timer(device_, timer_handle_);
timer_handle_ = nullptr;
}
}
SP_Timer Handle() { return timer_handle_; }
uint64 Microseconds() const override {
return timer_fns_->nanoseconds(timer_handle_) / 1000;
}
uint64 Nanoseconds() const override {
return timer_fns_->nanoseconds(timer_handle_);
}
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Timer timer_handle_;
SP_TimerFns* timer_fns_;
};
// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};

View File

@ -19,10 +19,12 @@ limitations under the License.
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"
using tensorflow::StatusFromTF_Status;
namespace stream_executor {
// Plugin initialization function that a device plugin
@ -37,6 +39,13 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
// testing).
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
namespace {
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
} // namespace
class CPlatform : public Platform {
public:
explicit CPlatform(SP_Platform platform,
@ -83,5 +92,111 @@ class CPlatform : public Platform {
stream_executor::ExecutorCache executor_cache_;
};
class CStream : public internal::StreamInterface {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
~CStream() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
return s;
}
void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
SP_Stream Handle() { return stream_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Stream stream_handle_;
};
class CEvent : public internal::EventInterface {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status Record(SP_Stream stream_handle) {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->record_event(device_, stream_handle, event_handle_,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (event_handle_ != nullptr) {
stream_executor_->destroy_event(device_, event_handle_);
event_handle_ = nullptr;
}
}
SP_Event Handle() { return event_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Event event_handle_;
};
class CTimer : public internal::TimerInterface {
public:
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
SP_TimerFns* timer_fns)
: device_(device),
stream_executor_(stream_executor),
timer_handle_(nullptr),
timer_fns_(timer_fns) {}
~CTimer() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (timer_handle_ != nullptr) {
stream_executor_->destroy_timer(device_, timer_handle_);
timer_handle_ = nullptr;
}
}
SP_Timer Handle() { return timer_handle_; }
uint64 Microseconds() const override {
return timer_fns_->nanoseconds(timer_handle_) / 1000;
}
uint64 Nanoseconds() const override {
return timer_fns_->nanoseconds(timer_handle_);
}
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Timer timer_handle_;
SP_TimerFns* timer_fns_;
};
} // namespace stream_executor
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@ -25,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream.h"
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
@ -168,6 +170,15 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
TF_SetStatus(status, TF_OK, "");
}
// This function is only for pluggable device
// it will return nullptr in all other cases
SP_Stream TF_GetStream(TF_OpKernelContext* ctx) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
auto c_stream = static_cast<stream_executor::CStream*>(
cc_ctx->op_device_context()->stream()->implementation());
return c_stream->Handle();
}
int TF_NumInputs(TF_OpKernelContext* ctx) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
return cc_ctx->num_inputs();

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
@ -65,6 +66,10 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
typedef struct TF_OpKernelContext TF_OpKernelContext;
// TF_InitKernel to do op/kernel registration.
// Plugin needs to implement this function to register all kernels.
void TF_InitKernel();
// Allocates a new kernel builder and returns a pointer to it.
//
// If non-null, TensorFlow will call create_func when it needs to instantiate
@ -128,6 +133,11 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
// --------------------------------------------------------------------------
// OpKernelContext routines
// TF_GetStream returns the SP_Stream available in ctx
// This function is only for pluggable device
// it will return nullptr in all other cased.
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);
// TF_NumInputs returns the number of inputs available in ctx.
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);

View File

@ -0,0 +1,24 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
TF_Status* const status) {
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
params->platform->name = "GPU";
params->platform->type = "XGPU";
}

View File

@ -101,4 +101,35 @@ Status LoadDynamicLibrary(const char* library_filename, void** result,
return Status::OK();
}
// Load a Pluggable Device library
// On sucess, returns the handle to library in result and return OK from the
// function. Otherwise return nullptr in result and error Status from the
// function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is laoded
// for the first time.
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result) {
static mutex mu(LINKER_INITIALIZED);
static std::unordered_map<string, Library> loaded_libs;
Env* env = Env::Default();
Library library;
{
mutex_lock lock(mu);
if (loaded_libs.find(library_filename) != loaded_libs.end()) {
library = loaded_libs[library_filename];
} else {
Status s = env->LoadDynamicLibrary(library_filename, &library.handle);
if (s.ok()) {
// Init PluggableDevice Plugin
} else {
return s;
}
loaded_libs[library_filename] = library;
}
*result = library.handle;
return Status::OK();
}
}
} // namespace tensorflow

View File

@ -5976,11 +5976,10 @@ py_library(
name = "pywrap_tensorflow",
srcs = [
"pywrap_tensorflow.py",
] + if_static(
["pywrap_dlopen_global_flags.py"],
# Import will fail, indicating no global dlopen flags
otherwise = [],
), # b/153585257
# modular TF need this file to load expose C API symbols
# in pywrap_tensorflow_internal.so
"pywrap_dlopen_global_flags.py",
],
srcs_version = "PY2AND3",
deps = [":pywrap_tensorflow_internal"],
)

View File

@ -36,9 +36,9 @@ import traceback
# go/tf-wildcard-import
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
# pylint: enable=wildcard-import

View File

@ -711,6 +711,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
},
py::return_value_policy::reference);
m.def(
"TF_LoadPluggableDeviceLibrary",
[](const char* library_filename) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
auto output =
TF_LoadPluggableDeviceLibrary(library_filename, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
},
py::return_value_policy::reference);
m.def("TF_GetOpList", [](TF_Library* lib_handle) {
TF_Buffer output_buffer = TF_GetOpList(lib_handle);
return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
@ -720,6 +732,11 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
py::call_guard<py::gil_scoped_release>());
m.def("TF_PluggableDeviceLibraryHandle",
TF_DeletePluggableDeviceLibraryHandle,
py::call_guard<py::gil_scoped_release>());
m.def("TF_AddControlInput", TF_AddControlInput);
m.def(
"TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {

View File

@ -1268,6 +1268,29 @@ class Context(object):
# Import device settings that may have been passed into the constructor
self._import_config()
def reinitialize_physical_devices(self):
"""Get local devices visible to the system."""
# We lazy initialize self._physical_devices since we do not want to do this
# the constructor since the backend may not be initialized yet.
with self._device_lock:
devs = pywrap_tfe.TF_ListPhysicalDevices()
self._physical_devices = [
PhysicalDevice(name=d.decode(),
device_type=d.decode().split(":")[1]) for d in devs]
self._physical_device_to_index = {
p: i for i, p in enumerate(self._physical_devices)
}
self._visible_device_list = list(self._physical_devices)
self._memory_growth_map = {
d: None for d in self._physical_devices if d.device_type == "GPU"
}
# Import device settings that may have been passed into the constructor
self._import_config()
def list_physical_devices(self, device_type=None):
"""List local devices visible to the system.

View File

@ -29,7 +29,7 @@ from tensorflow.python import _pywrap_python_op_gen
from tensorflow.python.client import pywrap_tf_session as py_tf
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.eager import context
@tf_export('load_op_library')
def load_op_library(library_filename):
@ -157,3 +157,44 @@ def load_library(library_location):
errno.ENOENT,
'The file or folder to load kernel libraries from does not exist.',
library_location)
@tf_export('load_pluggable_device_library')
def load_pluggable_device_library(library_location):
"""Loads a Tensorflow PluggableDevice plugin
"library_location" can be a path to a specific shared object, or a folder.
If it is a folder, all shared objects will be loaded. when the library is
loaded, devices/kernels registered in the library via StreamExecutor C API
and Kernel/Op Registration C API are made available in TensorFlow process
Args:
library_location: Path to the plugin or folder of plugins.
Relative or absolute filesystem path to a dynamic library file or folder.
Returns:
None
Raises:
OSError: When the file to be loaded is not found.
RuntimeError: when unable to load the library.
"""
if file_io.file_exists(library_location):
if file_io.is_directory(library_location):
directory_contents = file_io.list_directory(library_location)
pluggable_device_libraries = [
os.path.join(library_location, f) for f in directory_contents
if _is_shared_object(f)]
else:
pluggable_device_libraries = [library_location]
for lib in pluggable_device_libraries:
py_tf.TF_LoadPluggableDeviceLibrary(lib)
# Reinitialized physical devices list after plugin registration
context.context().reinitialize_physical_devices()
else:
raise OSError(
errno.ENOENT,
'The file or folder to load pluggable device libraries from does not exist.',
library_location)

View File

@ -1576,6 +1576,10 @@ tf_module {
name: "load_op_library"
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "load_pluggable_device_library"
argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "local_variables"
argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -772,6 +772,10 @@ tf_module {
name: "load_op_library"
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "load_pluggable_device_library"
argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "logical_and"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "