address comments

Zhoulong Jiang 2020-10-23 15:23:53 +00:00
parent 554a5fbd72
commit 4d9db8b809
8 changed files with 43 additions and 109 deletions

View File

@@ -37,7 +37,9 @@ limitations under the License.
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/net.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/platform/strcat.h"
@@ -747,6 +749,14 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
   opts->opts.validate_colocation_constraints = enable;
 }

+// Load a Pluggable Device library.
+// On success, returns the handle to the library in result and returns OK from
+// the function. Otherwise returns nullptr in result and an error Status from
+// the function.
+//
+// If `library_filename` has already been loaded, we return a cached handle.
+// Device and Kernels/Ops are registered as globals when a library is loaded
+// for the first time.
 TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                           TF_Status* status) {
 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
@@ -755,17 +765,28 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
   return nullptr;
 #else
   TF_Library* lib_handle = new TF_Library;
-  status->status = tensorflow::LoadPluggableDeviceLibrary(
-      library_filename, &lib_handle->lib_handle);
-  if (!status->status.ok()) {
-    delete lib_handle;
-    return nullptr;
+  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+  static std::unordered_map<std::string, void*> loaded_libs;
+  tensorflow::Env* env = tensorflow::Env::Default();
+  {
+    tensorflow::mutex_lock lock(mu);
+    if (loaded_libs.find(library_filename) != loaded_libs.end()) {
+      lib_handle->lib_handle = loaded_libs[library_filename];
+    } else {
+      status->status =
+          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
+      if (status->status.ok()) {
+        // Init PluggableDevice Plugin
+      } else {
+        delete lib_handle;
+        return nullptr;
+      }
+    }
+    return lib_handle;
   }
-  return lib_handle;
 #endif
 }

 void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
+  if (lib_handle == nullptr) return;
   delete lib_handle;
 }
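Side note on the hunk above: the new loader keeps a process-wide cache so a given library_filename is only loaded once per process. A minimal standalone sketch of that pattern, using plain std:: primitives instead of tensorflow::mutex, with a hypothetical do_load callback standing in for the real Env::LoadDynamicLibrary call:

#include <mutex>
#include <string>
#include <unordered_map>

// Returns a cached handle if `path` was loaded before; otherwise loads it once
// and caches the handle. `do_load` is a hypothetical stand-in for
// Env::LoadDynamicLibrary and returns nullptr on failure.
void* LoadLibraryOnce(const std::string& path,
                      void* (*do_load)(const std::string& path)) {
  static std::mutex mu;
  static std::unordered_map<std::string, void*> loaded;
  std::lock_guard<std::mutex> lock(mu);
  auto it = loaded.find(path);
  if (it != loaded.end()) return it->second;  // cache hit: reuse the handle
  void* handle = do_load(path);               // first load registers globals
  if (handle != nullptr) loaded[path] = handle;
  return handle;
}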

View File

@@ -308,7 +308,7 @@ TF_ImportGraphDefOptionsSetValidateColocationConstraints(
 // device and related kernels present in that library.
 //
 // Pass "library_filename" to a platform-specific mechanism for dynamically
-// loading a library. The reles for determining the exact location of the
+// loading a library. The rules for determining the exact location of the
 // library are platform-specific and are not documented here
 //
 // On success, place OK in status and return the newly created library handle.
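As a rough illustration of the contract documented above (not part of this change), a caller could drive the API as follows; the header path and library path are assumptions, and error handling is minimal:

#include <stdio.h>

#include "tensorflow/c/c_api_experimental.h"  // assumed location of the declarations

int main(void) {
  TF_Status* status = TF_NewStatus();
  TF_Library* lib =
      TF_LoadPluggableDeviceLibrary("/path/to/libmy_plugin.so", status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "load failed: %s\n", TF_Message(status));
  }
  // On success, the plugged-in device and its kernels/ops are now registered.
  TF_DeletePluggableDeviceLibraryHandle(lib);
  TF_DeleteStatus(status);
  return 0;
}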

View File

@@ -60,11 +60,3 @@ tf_cc_test(
         "//tensorflow/stream_executor:stream_executor_pimpl",
     ],
 )
-
-exports_files(
-    srcs = [
-        "stream_executor.h",
-        "stream_executor_internal.h",
-    ],
-    visibility = ["//visibility:public"],
-)

View File

@@ -2,20 +2,12 @@
 # test for stream_executor
 load(
     "//tensorflow:tensorflow.bzl",
-    "tf_custom_op_library",
+    "tf_cc_shared_object",
 )

-cc_library(
-    name = "test_pluggable_device_impl",
-    srcs = ["test_pluggable_device.cc"],
-    hdrs = [
-        "//tensorflow/c:headers",
-        "//tensorflow/c/experimental/stream_executor:stream_executor.h",
-    ],
-)
-
-tf_custom_op_library(
+tf_cc_shared_object(
     name = "test_pluggable_device.so",
+    srcs = ["test_pluggable_device.cc"],
     visibility = ["//tensorflow/c:__subpackages__"],
-    deps = [":test_pluggable_device_impl"],
+    deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
 )

View File

@@ -134,12 +134,13 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
 // --------------------------------------------------------------------------
 // OpKernelContext routines

-// TF_GetStream returns the SP_Stream available in ctx
+// TF_GetStream returns the SP_Stream available in ctx.
 // This function returns a stream only for devices registered using the
 // StreamExecutor C API
 // (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
-// nullptr in all other cases. Experimental: this function doesn't have
-// compatibility guarantees and subject to change at any time."
+// nullptr in all other cases.
+// Experimental: this function doesn't have compatibility guarantees and subject
+// to change at any time."
 TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);

 // TF_NumInputs returns the number of inputs available in ctx.
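To make the contract above concrete, a compute callback registered through the kernels C API could check the result before using it. MyDeviceCompute is an illustrative name, the include path is assumed to make SP_Stream and TF_GetStream visible, and the callback shape is the one taken by TF_NewKernelBuilder; this sketch is not part of the change:

#include "tensorflow/c/kernels.h"

// Compute callback with the void(void*, TF_OpKernelContext*) shape expected by
// TF_NewKernelBuilder. `kernel` is whatever the create callback returned.
static void MyDeviceCompute(void* kernel, TF_OpKernelContext* ctx) {
  SP_Stream stream = TF_GetStream(ctx);
  if (stream == nullptr) {
    // Device was not registered via the StreamExecutor C API; fall back or
    // report an error instead of dereferencing the stream.
    return;
  }
  // Enqueue device work on `stream` through the plugin's StreamExecutor hooks.
}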

View File

@@ -1,24 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
-
-void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
-                   TF_Status* const status) {
-  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
-  params->platform->name = "GPU";
-  params->platform->type = "XGPU";
-}

View File

@@ -101,35 +101,4 @@ Status LoadDynamicLibrary(const char* library_filename, void** result,
   return Status::OK();
 }

-// Load a Pluggable Device library
-// On sucess, returns the handle to library in result and return OK from the
-// function. Otherwise return nullptr in result and error Status from the
-// function.
-//
-// If `library_filename` has already been loaded, we return a cached handle.
-// Device and Kernels/Ops are registered as globals when a library is laoded
-// for the first time.
-Status LoadPluggableDeviceLibrary(const char* library_filename, void** result) {
-  static mutex mu(LINKER_INITIALIZED);
-  static std::unordered_map<string, Library> loaded_libs;
-  Env* env = Env::Default();
-  Library library;
-  {
-    mutex_lock lock(mu);
-    if (loaded_libs.find(library_filename) != loaded_libs.end()) {
-      library = loaded_libs[library_filename];
-    } else {
-      Status s = env->LoadDynamicLibrary(library_filename, &library.handle);
-      if (s.ok()) {
-        // Init PluggableDevice Plugin
-      } else {
-        return s;
-      }
-      loaded_libs[library_filename] = library;
-    }
-    *result = library.handle;
-    return Status::OK();
-  }
-}
-
 } // namespace tensorflow

View File

@@ -1244,12 +1244,12 @@ class Context(object):
   def invoking_op_callbacks(self, value):
     self._thread_local_data.invoking_op_callbacks = value

-  def _initialize_physical_devices(self):
+  def _initialize_physical_devices(self, reinitialize=False):
     """Get local devices visible to the system."""
     # We lazy initialize self._physical_devices since we do not want to do this
     # the constructor since the backend may not be initialized yet.
     with self._device_lock:
-      if self._physical_devices is not None:
+      if reinitialize is False and self._physical_devices is not None:
         return

       devs = pywrap_tfe.TF_ListPhysicalDevices()
@@ -1267,29 +1267,12 @@ class Context(object):
       # Import device settings that may have been passed into the constructor
       self._import_config()

   def reinitialize_physical_devices(self):
     """Get local devices visible to the system."""
-    # We lazy initialize self._physical_devices since we do not want to do this
-    # the constructor since the backend may not be initialized yet.
-    with self._device_lock:
-      devs = pywrap_tfe.TF_ListPhysicalDevices()
-
-      self._physical_devices = [
-          PhysicalDevice(name=d.decode(),
-                         device_type=d.decode().split(":")[1]) for d in devs]
-      self._physical_device_to_index = {
-          p: i for i, p in enumerate(self._physical_devices)
-      }
-
-      self._visible_device_list = list(self._physical_devices)
-      self._memory_growth_map = {
-          d: None for d in self._physical_devices if d.device_type == "GPU"
-      }
-
-      # Import device settings that may have been passed into the constructor
-      self._import_config()
+    # Reinitialize the physical device list after registering
+    # the pluggable device.
+    self._initialize_physical_devices(True)

   def list_physical_devices(self, device_type=None):
     """List local devices visible to the system.