Merge pull request #43610 from Intel-tensorflow:pluggable_device_load
PiperOrigin-RevId: 346437000 Change-Id: Iaa0617de52a61b10eb9182e883b81d9274237eec
This commit is contained in:
commit
d163783b0a
@ -145,6 +145,8 @@ if _running_from_pip_package():
|
|||||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||||
if _os.path.exists(_plugin_dir):
|
if _os.path.exists(_plugin_dir):
|
||||||
_ll.load_library(_plugin_dir)
|
_ll.load_library(_plugin_dir)
|
||||||
|
# Load Pluggable Device Library
|
||||||
|
_ll.load_pluggable_device_library(_plugin_dir)
|
||||||
|
|
||||||
# Add module aliases
|
# Add module aliases
|
||||||
if hasattr(_current_module, 'keras'):
|
if hasattr(_current_module, 'keras'):
|
||||||
|
@ -155,6 +155,8 @@ if _running_from_pip_package():
|
|||||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||||
if _os.path.exists(_plugin_dir):
|
if _os.path.exists(_plugin_dir):
|
||||||
_ll.load_library(_plugin_dir)
|
_ll.load_library(_plugin_dir)
|
||||||
|
# Load Pluggable Device Library
|
||||||
|
_ll.load_pluggable_device_library(_plugin_dir)
|
||||||
|
|
||||||
# Delete modules that should be hidden from dir().
|
# Delete modules that should be hidden from dir().
|
||||||
# Don't fail if these modules are not available.
|
# Don't fail if these modules are not available.
|
||||||
|
@ -684,7 +684,10 @@ tf_cc_test(
|
|||||||
name = "c_api_experimental_test",
|
name = "c_api_experimental_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["c_api_experimental_test.cc"],
|
srcs = ["c_api_experimental_test.cc"],
|
||||||
data = ["testdata/tf_record"],
|
data = [
|
||||||
|
"testdata/tf_record",
|
||||||
|
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
|
||||||
|
],
|
||||||
linkopts = select({
|
linkopts = select({
|
||||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
@ -704,6 +707,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:resource_loader",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -37,7 +37,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/platform/blocking_counter.h"
|
#include "tensorflow/core/platform/blocking_counter.h"
|
||||||
#include "tensorflow/core/platform/casts.h"
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/net.h"
|
#include "tensorflow/core/platform/net.h"
|
||||||
#include "tensorflow/core/platform/platform.h"
|
#include "tensorflow/core/platform/platform.h"
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||||
|
|
||||||
|
// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
|
||||||
|
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||||
@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
|||||||
TF_ImportGraphDefOptions* opts, unsigned char enable) {
|
TF_ImportGraphDefOptions* opts, unsigned char enable) {
|
||||||
opts->opts.validate_colocation_constraints = enable;
|
opts->opts.validate_colocation_constraints = enable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load a Pluggable Device library.
|
||||||
|
// On success, returns the handle to library in result and return OK from the
|
||||||
|
// function. Otherwise return nullptr in result and error Status from the
|
||||||
|
// function.
|
||||||
|
//
|
||||||
|
// If `library_filename` has already been loaded, we return a cached handle.
|
||||||
|
// Device and Kernels/Ops are registered as globals when a library is loaded
|
||||||
|
// for the first time.
|
||||||
|
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
|
||||||
|
TF_Status* status) {
|
||||||
|
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||||
|
status->status = tensorflow::errors::Unimplemented(
|
||||||
|
"PluggableDevice plugin functionality is not supported on mobile");
|
||||||
|
return nullptr;
|
||||||
|
#else
|
||||||
|
TF_Library* lib_handle = new TF_Library;
|
||||||
|
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
||||||
|
static std::unordered_map<std::string, void*>* loaded_libs =
|
||||||
|
new std::unordered_map<std::string, void*>();
|
||||||
|
tensorflow::Env* env = tensorflow::Env::Default();
|
||||||
|
{
|
||||||
|
tensorflow::mutex_lock lock(mu);
|
||||||
|
auto it = loaded_libs->find(library_filename);
|
||||||
|
if (it != loaded_libs->end()) {
|
||||||
|
lib_handle->lib_handle = it->second;
|
||||||
|
} else {
|
||||||
|
status->status =
|
||||||
|
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
delete lib_handle;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lib_handle;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frees the TF_Library wrapper object itself. Nothing in this function
// unloads the shared library the wrapper refers to. Passing nullptr is
// harmless because `delete` on a null pointer is a no-op.
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  delete lib_handle;
}
|
||||||
|
@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void
|
|||||||
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||||
TF_ImportGraphDefOptions* opts, unsigned char enable);
|
TF_ImportGraphDefOptions* opts, unsigned char enable);
|
||||||
|
|
||||||
|
// Load the library specified by library_filename and register the pluggable
|
||||||
|
// device and related kernels present in that library. This function is not
|
||||||
|
// supported on mobile and embedded platforms and will fail if
|
||||||
|
// called.
|
||||||
|
//
|
||||||
|
// Pass "library_filename" to a platform-specific mechanism for dynamically
|
||||||
|
// loading a library. The rules for determining the exact location of the
|
||||||
|
// library are platform-specific and are not documented here.
|
||||||
|
//
|
||||||
|
// On success, returns the newly created library handle and places OK in status.
|
||||||
|
// The caller owns the library handle.
|
||||||
|
//
|
||||||
|
// On failure, returns nullptr and places an error status in status.
|
||||||
|
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
|
||||||
|
const char* library_filename, TF_Status* status);
|
||||||
|
|
||||||
|
// Frees the memory associated with the library handle.
|
||||||
|
// Does NOT unload the library.
|
||||||
|
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
|
||||||
|
TF_Library* lib_handle);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/resource_loader.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||||
|
|
||||||
@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
|
|||||||
TF_DeleteTensor(tensor_1X6);
|
TF_DeleteTensor(tensor_1X6);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Loads the test PluggableDevice plugin that is shipped as a data dependency
// through the C API and checks that loading reports TF_OK, then releases the
// returned handle. Compiled out when shared objects are not supported.
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  // Resolve the runfiles location of the plugin shared object.
  const string plugin_relpath = tensorflow::io::JoinPath(
      "tensorflow", "c", "experimental", "stream_executor", "test",
      "test_pluggable_device.so");
  const string plugin_path =
      tensorflow::GetDataDependencyFilepath(plugin_relpath);

  // Load the library.
  TF_Status* load_status = TF_NewStatus();
  TF_Library* plugin =
      TF_LoadPluggableDeviceLibrary(plugin_path.c_str(), load_status);

  // Copy the code and message out first so the status object can be freed
  // before the assertion (ASSERT_* returns early on failure).
  const TF_Code load_code = TF_GetCode(load_status);
  const string load_message(TF_Message(load_status));
  TF_DeleteStatus(load_status);

  ASSERT_EQ(TF_OK, load_code) << load_message;
  TF_DeletePluggableDeviceLibraryHandle(plugin);
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
17
tensorflow/c/experimental/stream_executor/test/BUILD
Normal file
17
tensorflow/c/experimental/stream_executor/test/BUILD
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Description:
|
||||||
|
# test for stream_executor
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_cc_shared_object",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_shared_object(
|
||||||
|
name = "test_pluggable_device.so",
|
||||||
|
srcs = ["test_pluggable_device.cc"],
|
||||||
|
visibility = ["//tensorflow/c:__subpackages__"],
|
||||||
|
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
|
||||||
|
)
|
@ -0,0 +1,23 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||||
|
|
||||||
|
// Plugin entry point for the test PluggableDevice library. It fills in the
// struct size, name and type of the platform referenced by `params`;
// `status` is not written.
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
                   TF_Status* const status) {
  auto* const platform = params->platform;
  platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  platform->name = "GPU";
  platform->type = "XGPU";
}
|
@ -36,9 +36,9 @@ import traceback
|
|||||||
|
|
||||||
# go/tf-wildcard-import
|
# go/tf-wildcard-import
|
||||||
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
|
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
|
||||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||||
|
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
|
|
||||||
|
@ -711,6 +711,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
|
|||||||
},
|
},
|
||||||
py::return_value_policy::reference);
|
py::return_value_policy::reference);
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"TF_LoadPluggableDeviceLibrary",
|
||||||
|
[](const char* library_filename) {
|
||||||
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
|
auto output =
|
||||||
|
TF_LoadPluggableDeviceLibrary(library_filename, status.get());
|
||||||
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
|
return output;
|
||||||
|
},
|
||||||
|
py::return_value_policy::reference);
|
||||||
|
|
||||||
m.def("TF_GetOpList", [](TF_Library* lib_handle) {
|
m.def("TF_GetOpList", [](TF_Library* lib_handle) {
|
||||||
TF_Buffer output_buffer = TF_GetOpList(lib_handle);
|
TF_Buffer output_buffer = TF_GetOpList(lib_handle);
|
||||||
return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
|
return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
|
||||||
@ -720,6 +732,11 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
|
|||||||
|
|
||||||
m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
|
m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
|
||||||
py::call_guard<py::gil_scoped_release>());
|
py::call_guard<py::gil_scoped_release>());
|
||||||
|
|
||||||
|
// Python binding for TF_DeletePluggableDeviceLibraryHandle, exposed under the
// same name as the C function (matching the TF_DeleteLibraryHandle binding).
// The GIL is released for the duration of the call.
m.def("TF_DeletePluggableDeviceLibraryHandle",
      TF_DeletePluggableDeviceLibraryHandle,
      py::call_guard<py::gil_scoped_release>());
// Original binding name, which omits the "Delete" prefix. Kept as an alias so
// existing Python callers keep working.
m.def("TF_PluggableDeviceLibraryHandle",
      TF_DeletePluggableDeviceLibraryHandle,
      py::call_guard<py::gil_scoped_release>());
|
||||||
|
|
||||||
m.def("TF_AddControlInput", TF_AddControlInput);
|
m.def("TF_AddControlInput", TF_AddControlInput);
|
||||||
m.def(
|
m.def(
|
||||||
"TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {
|
"TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {
|
||||||
|
@ -1245,12 +1245,17 @@ class Context(object):
|
|||||||
def invoking_op_callbacks(self, value):
|
def invoking_op_callbacks(self, value):
|
||||||
self._thread_local_data.invoking_op_callbacks = value
|
self._thread_local_data.invoking_op_callbacks = value
|
||||||
|
|
||||||
def _initialize_physical_devices(self):
|
def _initialize_physical_devices(self, reinitialize=False):
|
||||||
"""Get local devices visible to the system."""
|
"""Gets local devices visible to the system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reinitialize: If True, reinitializes self._physical_devices so that
|
||||||
|
dynamically registered devices will also be visible to the Python front-end.
|
||||||
|
"""
|
||||||
# We lazy initialize self._physical_devices since we do not want to do this
|
# We lazy initialize self._physical_devices since we do not want to do this
|
||||||
# the constructor since the backend may not be initialized yet.
|
# the constructor since the backend may not be initialized yet.
|
||||||
with self._device_lock:
|
with self._device_lock:
|
||||||
if self._physical_devices is not None:
|
if not reinitialize and self._physical_devices is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
devs = pywrap_tfe.TF_ListPhysicalDevices()
|
devs = pywrap_tfe.TF_ListPhysicalDevices()
|
||||||
@ -1269,6 +1274,12 @@ class Context(object):
|
|||||||
# Import device settings that may have been passed into the constructor
|
# Import device settings that may have been passed into the constructor
|
||||||
self._import_config()
|
self._import_config()
|
||||||
|
|
||||||
|
def reinitialize_physical_devices(self):
  """Rebuilds the cached list of physical devices.

  Forces `_initialize_physical_devices` to run again even if the device list
  was already populated. Intended to be called after a pluggable device has
  been registered.
  """
  self._initialize_physical_devices(reinitialize=True)
|
||||||
|
|
||||||
def list_physical_devices(self, device_type=None):
|
def list_physical_devices(self, device_type=None):
|
||||||
"""List local devices visible to the system.
|
"""List local devices visible to the system.
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ import sys
|
|||||||
|
|
||||||
from tensorflow.python import _pywrap_python_op_gen
|
from tensorflow.python import _pywrap_python_op_gen
|
||||||
from tensorflow.python.client import pywrap_tf_session as py_tf
|
from tensorflow.python.client import pywrap_tf_session as py_tf
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -159,6 +160,45 @@ def load_library(library_location):
|
|||||||
library_location)
|
library_location)
|
||||||
|
|
||||||
|
|
||||||
|
def load_pluggable_device_library(library_location):
  """Loads one or more TensorFlow PluggableDevice plugins.

  `library_location` may name a single shared object or a folder. For a
  folder, every entry accepted by `_is_shared_object` is loaded. After the
  plugins are loaded, the eager context's physical device list is
  reinitialized.

  Args:
    library_location: Relative or absolute filesystem path to a plugin shared
      object, or to a folder containing plugins.

  Raises:
    OSError: When `library_location` does not exist.
    RuntimeError: When unable to load the library.
  """
  if not os.path.exists(library_location):
    raise OSError(
        errno.ENOENT,
        'The file or folder to load pluggable device libraries from does not '
        'exist.', library_location)

  if os.path.isdir(library_location):
    plugin_paths = []
    for entry in os.listdir(library_location):
      if _is_shared_object(entry):
        plugin_paths.append(os.path.join(library_location, entry))
  else:
    plugin_paths = [library_location]

  for plugin_path in plugin_paths:
    py_tf.TF_LoadPluggableDeviceLibrary(plugin_path)
  # Reinitialize the physical device list after plugin registration.
  context.context().reinitialize_physical_devices()
|
||||||
|
|
||||||
|
|
||||||
@tf_export('experimental.register_filesystem_plugin')
|
@tf_export('experimental.register_filesystem_plugin')
|
||||||
def register_filesystem_plugin(plugin_location):
|
def register_filesystem_plugin(plugin_location):
|
||||||
"""Loads a TensorFlow FileSystem plugin.
|
"""Loads a TensorFlow FileSystem plugin.
|
||||||
|
Loading…
Reference in New Issue
Block a user