Merge pull request #43610 from Intel-tensorflow:pluggable_device_load

PiperOrigin-RevId: 346437000
Change-Id: Iaa0617de52a61b10eb9182e883b81d9274237eec
This commit is contained in:
TensorFlower Gardener 2020-12-08 16:47:36 -08:00
commit d163783b0a
12 changed files with 207 additions and 5 deletions

View File

@ -145,6 +145,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins') _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir): if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir) _ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Add module aliases # Add module aliases
if hasattr(_current_module, 'keras'): if hasattr(_current_module, 'keras'):

View File

@ -155,6 +155,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins') _plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir): if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir) _ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Delete modules that should be hidden from dir(). # Delete modules that should be hidden from dir().
# Don't fail if these modules are not available. # Don't fail if these modules are not available.

View File

@ -684,7 +684,10 @@ tf_cc_test(
name = "c_api_experimental_test", name = "c_api_experimental_test",
size = "medium", size = "medium",
srcs = ["c_api_experimental_test.cc"], srcs = ["c_api_experimental_test.cc"],
data = ["testdata/tf_record"], data = [
"testdata/tf_record",
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
],
linkopts = select({ linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"], "//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [], "//conditions:default": [],
@ -704,6 +707,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
) )

View File

@ -37,7 +37,9 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/strcat.h"
@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
namespace tensorflow { namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
} // namespace tensorflow } // namespace tensorflow
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable) { TF_ImportGraphDefOptions* opts, unsigned char enable) {
opts->opts.validate_colocation_constraints = enable; opts->opts.validate_colocation_constraints = enable;
} }
// Load a Pluggable Device library.
// On success, returns the handle to library in result and return OK from the
// function. Otherwise return nullptr in result and error Status from the
// function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  // Filename -> OS library handle cache, guarded by `mu`. Both statics are
  // intentionally leaked to avoid static-destruction-order problems.
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*>* loaded_libs =
      new std::unordered_map<std::string, void*>();
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    auto it = loaded_libs->find(library_filename);
    if (it != loaded_libs->end()) {
      // Already loaded: reuse the cached OS handle so device/kernel
      // registration (a load-time global side effect) is not repeated.
      lib_handle->lib_handle = it->second;
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (!status->status.ok()) {
        delete lib_handle;
        return nullptr;
      }
      // BUG FIX: record the freshly loaded library in the cache. The
      // original code never inserted into `loaded_libs`, so the cached-handle
      // path above could never be taken and every call re-loaded the .so.
      (*loaded_libs)[library_filename] = lib_handle->lib_handle;
    }
  }
  return lib_handle;
#endif
}
// Frees the TF_Library wrapper returned by TF_LoadPluggableDeviceLibrary.
// Only the handle object is released; the underlying shared object is NOT
// unloaded, so devices/kernels it registered stay available to the process.
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
delete lib_handle;
}

View File

@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void
TF_ImportGraphDefOptionsSetValidateColocationConstraints( TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable); TF_ImportGraphDefOptions* opts, unsigned char enable);
// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library. This function is not
// supported on mobile and embedded platforms and will fail if called.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, returns the newly created library handle and places OK in status.
// The caller owns the library handle.
//
// On failure, returns nullptr and places an error status in status.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
const char* library_filename, TF_Status* status);
// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
TF_Library* lib_handle);
#ifdef __cplusplus #ifdef __cplusplus
} /* end extern "C" */ } /* end extern "C" */
#endif #endif

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
TF_DeleteTensor(tensor_1X6); TF_DeleteTensor(tensor_1X6);
} }
// Loads the test_pluggable_device.so data dependency through the C API and
// verifies the load reports TF_OK, then frees the handle.
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
// Shared-object loading is unavailable in static builds, so the body is
// compiled out when TENSORFLOW_NO_SHARED_OBJECTS is defined.
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
// Load the library.
TF_Status* status = TF_NewStatus();
string lib_path =
tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
"tensorflow", "c", "experimental", "stream_executor", "test",
"test_pluggable_device.so"));
TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
// Copy code/message out before deleting the status so the assertion below
// can still report the error text on failure.
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
TF_DeletePluggableDeviceLibraryHandle(lib);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1,17 @@
# Description:
#   Test-only targets for the stream_executor C API.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_shared_object",
)
package(
licenses = ["notice"],  # Apache 2.0
)
# Minimal PluggableDevice plugin (.so). Built as a data dependency of
# //tensorflow/c:c_api_experimental_test to exercise
# TF_LoadPluggableDeviceLibrary; depends only on the stream_executor headers
# so it stays a lightweight shared object.
tf_cc_shared_object(
name = "test_pluggable_device.so",
srcs = ["test_pluggable_device.cc"],
visibility = ["//tensorflow/c:__subpackages__"],
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
)

View File

@ -0,0 +1,23 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
// Plugin entry point invoked by TensorFlow when the shared object is loaded.
// This is a minimal stub for the load test: it only fills in the platform
// registration params and registers no real devices.
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
TF_Status* const status) {
// struct_size lets the loader check ABI compatibility of the filled struct.
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
params->platform->name = "GPU";
params->platform->type = "XGPU";
// NOTE(review): `status` is left untouched — presumably it arrives as TF_OK;
// confirm the loader treats an unmodified status as success.
}

View File

@ -36,9 +36,9 @@ import traceback
# go/tf-wildcard-import # go/tf-wildcard-import
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top # pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
# pylint: enable=wildcard-import # pylint: enable=wildcard-import

View File

@ -711,6 +711,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
}, },
py::return_value_policy::reference); py::return_value_policy::reference);
m.def(
"TF_LoadPluggableDeviceLibrary",
[](const char* library_filename) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
auto output =
TF_LoadPluggableDeviceLibrary(library_filename, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
},
py::return_value_policy::reference);
m.def("TF_GetOpList", [](TF_Library* lib_handle) { m.def("TF_GetOpList", [](TF_Library* lib_handle) {
TF_Buffer output_buffer = TF_GetOpList(lib_handle); TF_Buffer output_buffer = TF_GetOpList(lib_handle);
return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize( return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
@ -720,6 +732,11 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle, m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
// NOTE(review): the exposed Python name omits "Delete" — this binds
// TF_DeletePluggableDeviceLibraryHandle under the name
// "TF_PluggableDeviceLibraryHandle". Looks like a typo; confirm which name
// Python callers expect before relying on (or renaming) it.
m.def("TF_PluggableDeviceLibraryHandle",
TF_DeletePluggableDeviceLibraryHandle,
py::call_guard<py::gil_scoped_release>());
m.def("TF_AddControlInput", TF_AddControlInput); m.def("TF_AddControlInput", TF_AddControlInput);
m.def( m.def(
"TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) { "TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {

View File

@ -1245,12 +1245,17 @@ class Context(object):
def invoking_op_callbacks(self, value): def invoking_op_callbacks(self, value):
self._thread_local_data.invoking_op_callbacks = value self._thread_local_data.invoking_op_callbacks = value
def _initialize_physical_devices(self): def _initialize_physical_devices(self, reinitialize=False):
"""Get local devices visible to the system.""" """Gets local devices visible to the system.
Args:
reinitialize: If True, reinitializes self._physical_devices so that
dynamic registered devices will also be visible to the python front-end.
"""
# We lazy initialize self._physical_devices since we do not want to do this # We lazy initialize self._physical_devices since we do not want to do this
# the constructor since the backend may not be initialized yet. # the constructor since the backend may not be initialized yet.
with self._device_lock: with self._device_lock:
if self._physical_devices is not None: if not reinitialize and self._physical_devices is not None:
return return
devs = pywrap_tfe.TF_ListPhysicalDevices() devs = pywrap_tfe.TF_ListPhysicalDevices()
@ -1269,6 +1274,12 @@ class Context(object):
# Import device settings that may have been passed into the constructor # Import device settings that may have been passed into the constructor
self._import_config() self._import_config()
def reinitialize_physical_devices(self):
"""Rebuilds the physical device list.

Called after a PluggableDevice plugin registers new devices so they
become visible to the Python front-end.
"""
# Reinitialize the physical device list after registering
# the pluggable device.
self._initialize_physical_devices(True)
def list_physical_devices(self, device_type=None): def list_physical_devices(self, device_type=None):
"""List local devices visible to the system. """List local devices visible to the system.

View File

@ -27,6 +27,7 @@ import sys
from tensorflow.python import _pywrap_python_op_gen from tensorflow.python import _pywrap_python_op_gen
from tensorflow.python.client import pywrap_tf_session as py_tf from tensorflow.python.client import pywrap_tf_session as py_tf
from tensorflow.python.eager import context
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -159,6 +160,45 @@ def load_library(library_location):
library_location) library_location)
def load_pluggable_device_library(library_location):
"""Loads a TensorFlow PluggableDevice plugin.
"library_location" can be a path to a specific shared object, or a folder.
If it is a folder, all shared objects will be loaded. When a library is
loaded, devices/kernels registered in the library via the StreamExecutor C
API and the Kernel/Op Registration C API are made available in the
TensorFlow process.
Args:
library_location: Path to the plugin or folder of plugins. Relative or
absolute filesystem path to a dynamic library file or folder.
Raises:
OSError: When the file to be loaded is not found.
RuntimeError: when unable to load the library.
"""
if os.path.exists(library_location):
if os.path.isdir(library_location):
# Folder: load every shared object found directly inside it.
directory_contents = os.listdir(library_location)
pluggable_device_libraries = [
os.path.join(library_location, f)
for f in directory_contents
if _is_shared_object(f)
]
else:
pluggable_device_libraries = [library_location]
for lib in pluggable_device_libraries:
py_tf.TF_LoadPluggableDeviceLibrary(lib)
# Reinitialize the physical device list so the newly registered
# pluggable devices are visible to the Python front-end.
context.context().reinitialize_physical_devices()
else:
raise OSError(
errno.ENOENT,
'The file or folder to load pluggable device libraries from does not '
'exist.', library_location)
@tf_export('experimental.register_filesystem_plugin') @tf_export('experimental.register_filesystem_plugin')
def register_filesystem_plugin(plugin_location): def register_filesystem_plugin(plugin_location):
"""Loads a TensorFlow FileSystem plugin. """Loads a TensorFlow FileSystem plugin.