Merge pull request #43610 from Intel-tensorflow:pluggable_device_load
PiperOrigin-RevId: 346437000 Change-Id: Iaa0617de52a61b10eb9182e883b81d9274237eec
This commit is contained in:
commit
d163783b0a
@ -145,6 +145,8 @@ if _running_from_pip_package():
|
||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||
if _os.path.exists(_plugin_dir):
|
||||
_ll.load_library(_plugin_dir)
|
||||
# Load Pluggable Device Library
|
||||
_ll.load_pluggable_device_library(_plugin_dir)
|
||||
|
||||
# Add module aliases
|
||||
if hasattr(_current_module, 'keras'):
|
||||
|
@ -155,6 +155,8 @@ if _running_from_pip_package():
|
||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||
if _os.path.exists(_plugin_dir):
|
||||
_ll.load_library(_plugin_dir)
|
||||
# Load Pluggable Device Library
|
||||
_ll.load_pluggable_device_library(_plugin_dir)
|
||||
|
||||
# Delete modules that should be hidden from dir().
|
||||
# Don't fail if these modules are not available.
|
||||
|
@ -684,7 +684,10 @@ tf_cc_test(
|
||||
name = "c_api_experimental_test",
|
||||
size = "medium",
|
||||
srcs = ["c_api_experimental_test.cc"],
|
||||
data = ["testdata/tf_record"],
|
||||
data = [
|
||||
"testdata/tf_record",
|
||||
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
|
||||
],
|
||||
linkopts = select({
|
||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
@ -704,6 +707,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -37,7 +37,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
|
||||
|
||||
namespace tensorflow {
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
|
||||
// Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file).
|
||||
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
|
||||
} // namespace tensorflow
|
||||
|
||||
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
TF_ImportGraphDefOptions* opts, unsigned char enable) {
|
||||
opts->opts.validate_colocation_constraints = enable;
|
||||
}
|
||||
|
||||
// Load a Pluggable Device library.
|
||||
// On success, returns the handle to library in result and return OK from the
|
||||
// function. Otherwise return nullptr in result and error Status from the
|
||||
// function.
|
||||
//
|
||||
// If `library_filename` has already been loaded, we return a cached handle.
|
||||
// Device and Kernels/Ops are registered as globals when a library is loaded
|
||||
// for the first time.
|
||||
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
|
||||
TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"PluggableDevice plugin functionality is not supported on mobile");
|
||||
return nullptr;
|
||||
#else
|
||||
TF_Library* lib_handle = new TF_Library;
|
||||
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
||||
static std::unordered_map<std::string, void*>* loaded_libs =
|
||||
new std::unordered_map<std::string, void*>();
|
||||
tensorflow::Env* env = tensorflow::Env::Default();
|
||||
{
|
||||
tensorflow::mutex_lock lock(mu);
|
||||
auto it = loaded_libs->find(library_filename);
|
||||
if (it != loaded_libs->end()) {
|
||||
lib_handle->lib_handle = it->second;
|
||||
} else {
|
||||
status->status =
|
||||
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
|
||||
if (!status->status.ok()) {
|
||||
delete lib_handle;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return lib_handle;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
|
||||
delete lib_handle;
|
||||
}
|
||||
|
@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void
|
||||
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
TF_ImportGraphDefOptions* opts, unsigned char enable);
|
||||
|
||||
// Load the library specified by library_filename and register the pluggable
|
||||
// device and related kernels present in that library. This function is not
|
||||
// supported on embedded on mobile and embedded platforms and will fail if
|
||||
// called.
|
||||
//
|
||||
// Pass "library_filename" to a platform-specific mechanism for dynamically
|
||||
// loading a library. The rules for determining the exact location of the
|
||||
// library are platform-specific and are not documented here.
|
||||
//
|
||||
// On success, returns the newly created library handle and places OK in status.
|
||||
// The caller owns the library handle.
|
||||
//
|
||||
// On failure, returns nullptr and places an error status in status.
|
||||
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
|
||||
const char* library_filename, TF_Status* status);
|
||||
|
||||
// Frees the memory associated with the library handle.
|
||||
// Does NOT unload the library.
|
||||
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
|
||||
TF_Library* lib_handle);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
|
||||
TF_DeleteTensor(tensor_1X6);
|
||||
}
|
||||
|
||||
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
|
||||
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
// Load the library.
|
||||
TF_Status* status = TF_NewStatus();
|
||||
string lib_path =
|
||||
tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
|
||||
"tensorflow", "c", "experimental", "stream_executor", "test",
|
||||
"test_pluggable_device.so"));
|
||||
TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
|
||||
TF_Code code = TF_GetCode(status);
|
||||
string status_msg(TF_Message(status));
|
||||
TF_DeleteStatus(status);
|
||||
ASSERT_EQ(TF_OK, code) << status_msg;
|
||||
TF_DeletePluggableDeviceLibraryHandle(lib);
|
||||
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
17
tensorflow/c/experimental/stream_executor/test/BUILD
Normal file
17
tensorflow/c/experimental/stream_executor/test/BUILD
Normal file
@ -0,0 +1,17 @@
|
||||
# Description:
|
||||
# test for stream_executor
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_shared_object",
|
||||
)
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_cc_shared_object(
|
||||
name = "test_pluggable_device.so",
|
||||
srcs = ["test_pluggable_device.cc"],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
|
||||
)
|
@ -0,0 +1,23 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
|
||||
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
|
||||
TF_Status* const status) {
|
||||
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
|
||||
params->platform->name = "GPU";
|
||||
params->platform->type = "XGPU";
|
||||
}
|
@ -36,9 +36,9 @@ import traceback
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
|
||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
|
@ -711,6 +711,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
|
||||
m.def(
|
||||
"TF_LoadPluggableDeviceLibrary",
|
||||
[](const char* library_filename) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
auto output =
|
||||
TF_LoadPluggableDeviceLibrary(library_filename, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
|
||||
m.def("TF_GetOpList", [](TF_Library* lib_handle) {
|
||||
TF_Buffer output_buffer = TF_GetOpList(lib_handle);
|
||||
return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
|
||||
@ -720,6 +732,11 @@ PYBIND11_MODULE(_pywrap_tf_session, m) {
|
||||
|
||||
m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
m.def("TF_PluggableDeviceLibraryHandle",
|
||||
TF_DeletePluggableDeviceLibraryHandle,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
m.def("TF_AddControlInput", TF_AddControlInput);
|
||||
m.def(
|
||||
"TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {
|
||||
|
@ -1245,12 +1245,17 @@ class Context(object):
|
||||
def invoking_op_callbacks(self, value):
|
||||
self._thread_local_data.invoking_op_callbacks = value
|
||||
|
||||
def _initialize_physical_devices(self):
|
||||
"""Get local devices visible to the system."""
|
||||
def _initialize_physical_devices(self, reinitialize=False):
|
||||
"""Gets local devices visible to the system.
|
||||
|
||||
Args:
|
||||
reinitialize: If True, reinitializes self._physical_devices so that
|
||||
dynamic registered devices will also be visible to the python front-end.
|
||||
"""
|
||||
# We lazy initialize self._physical_devices since we do not want to do this
|
||||
# the constructor since the backend may not be initialized yet.
|
||||
with self._device_lock:
|
||||
if self._physical_devices is not None:
|
||||
if not reinitialize and self._physical_devices is not None:
|
||||
return
|
||||
|
||||
devs = pywrap_tfe.TF_ListPhysicalDevices()
|
||||
@ -1269,6 +1274,12 @@ class Context(object):
|
||||
# Import device settings that may have been passed into the constructor
|
||||
self._import_config()
|
||||
|
||||
def reinitialize_physical_devices(self):
|
||||
"""Gets local devices visible to the system."""
|
||||
# Reinitialize the physical device list after registering
|
||||
# the pluggable device.
|
||||
self._initialize_physical_devices(True)
|
||||
|
||||
def list_physical_devices(self, device_type=None):
|
||||
"""List local devices visible to the system.
|
||||
|
||||
|
@ -27,6 +27,7 @@ import sys
|
||||
|
||||
from tensorflow.python import _pywrap_python_op_gen
|
||||
from tensorflow.python.client import pywrap_tf_session as py_tf
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -159,6 +160,45 @@ def load_library(library_location):
|
||||
library_location)
|
||||
|
||||
|
||||
def load_pluggable_device_library(library_location):
|
||||
"""Loads a TensorFlow PluggableDevice plugin.
|
||||
|
||||
"library_location" can be a path to a specific shared object, or a folder.
|
||||
If it is a folder, all shared objects will be loaded. when the library is
|
||||
loaded, devices/kernels registered in the library via StreamExecutor C API
|
||||
and Kernel/Op Registration C API are made available in TensorFlow process.
|
||||
|
||||
Args:
|
||||
library_location: Path to the plugin or folder of plugins. Relative or
|
||||
absolute filesystem path to a dynamic library file or folder.
|
||||
|
||||
Raises:
|
||||
OSError: When the file to be loaded is not found.
|
||||
RuntimeError: when unable to load the library.
|
||||
"""
|
||||
if os.path.exists(library_location):
|
||||
if os.path.isdir(library_location):
|
||||
directory_contents = os.listdir(library_location)
|
||||
|
||||
pluggable_device_libraries = [
|
||||
os.path.join(library_location, f)
|
||||
for f in directory_contents
|
||||
if _is_shared_object(f)
|
||||
]
|
||||
else:
|
||||
pluggable_device_libraries = [library_location]
|
||||
|
||||
for lib in pluggable_device_libraries:
|
||||
py_tf.TF_LoadPluggableDeviceLibrary(lib)
|
||||
# Reinitialized physical devices list after plugin registration.
|
||||
context.context().reinitialize_physical_devices()
|
||||
else:
|
||||
raise OSError(
|
||||
errno.ENOENT,
|
||||
'The file or folder to load pluggable device libraries from does not '
|
||||
'exist.', library_location)
|
||||
|
||||
|
||||
@tf_export('experimental.register_filesystem_plugin')
|
||||
def register_filesystem_plugin(plugin_location):
|
||||
"""Loads a TensorFlow FileSystem plugin.
|
||||
|
Loading…
Reference in New Issue
Block a user