address comments

This commit is contained in:
Zhoulong Jiang 2020-10-23 15:23:53 +00:00
parent 554a5fbd72
commit 4d9db8b809
8 changed files with 43 additions and 109 deletions

View File

@ -37,7 +37,9 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
@ -747,6 +749,14 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
opts->opts.validate_colocation_constraints = enable;
}
// Load a Pluggable Device library
// On sucess, returns the handle to library in result and return OK from the
// function. Otherwise return nullptr in result and error Status from the
// function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
@ -755,17 +765,28 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
return nullptr;
#else
TF_Library* lib_handle = new TF_Library;
status->status = tensorflow::LoadPluggableDeviceLibrary(
library_filename, &lib_handle->lib_handle);
if (!status->status.ok()) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<std::string, void*> loaded_libs;
tensorflow::Env* env = tensorflow::Env::Default();
{
tensorflow::mutex_lock lock(mu);
if (loaded_libs.find(library_filename) != loaded_libs.end()) {
lib_handle->lib_handle = loaded_libs[library_filename];
} else {
status->status =
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
if (status->status.ok()) {
// Init PluggableDevice Plugin
} else {
delete lib_handle;
return nullptr;
}
}
return lib_handle;
}
#endif
}
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
if (lib_handle == nullptr) return;
delete lib_handle;
}

View File

@ -308,7 +308,7 @@ TF_ImportGraphDefOptionsSetValidateColocationConstraints(
// device and related kernels present in that library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The reles for determining the exact location of the
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here
//
// On success, place OK in status and return the newly created library handle.

View File

@ -60,11 +60,3 @@ tf_cc_test(
"//tensorflow/stream_executor:stream_executor_pimpl",
],
)
exports_files(
srcs = [
"stream_executor.h",
"stream_executor_internal.h",
],
visibility = ["//visibility:public"],
)

View File

@ -2,20 +2,12 @@
# test for stream_executor
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_cc_shared_object",
)
cc_library(
name = "test_pluggable_device_impl",
srcs = ["test_pluggable_device.cc"],
hdrs = [
"//tensorflow/c:headers",
"//tensorflow/c/experimental/stream_executor:stream_executor.h",
],
)
tf_custom_op_library(
tf_cc_shared_object(
name = "test_pluggable_device.so",
srcs = ["test_pluggable_device.cc"],
visibility = ["//tensorflow/c:__subpackages__"],
deps = [":test_pluggable_device_impl"],
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
)

View File

@ -134,12 +134,13 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
// --------------------------------------------------------------------------
// OpKernelContext routines
// TF_GetStream returns the SP_Stream available in ctx
// TF_GetStream returns the SP_Stream available in ctx.
// This function returns a stream only for devices registered using the
// StreamExecutor C API
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
// nullptr in all other cases. Experimental: this function doesn't have
// compatibility guarantees and subject to change at any time."
// nullptr in all other cases.
// Experimental: this function doesn't have compatibility guarantees and subject
// to change at any time."
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);
// TF_NumInputs returns the number of inputs available in ctx.

View File

@ -1,24 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
TF_Status* const status) {
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
params->platform->name = "GPU";
params->platform->type = "XGPU";
}

View File

@ -101,35 +101,4 @@ Status LoadDynamicLibrary(const char* library_filename, void** result,
return Status::OK();
}
// Load a Pluggable Device library
// On sucess, returns the handle to library in result and return OK from the
// function. Otherwise return nullptr in result and error Status from the
// function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is laoded
// for the first time.
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result) {
static mutex mu(LINKER_INITIALIZED);
static std::unordered_map<string, Library> loaded_libs;
Env* env = Env::Default();
Library library;
{
mutex_lock lock(mu);
if (loaded_libs.find(library_filename) != loaded_libs.end()) {
library = loaded_libs[library_filename];
} else {
Status s = env->LoadDynamicLibrary(library_filename, &library.handle);
if (s.ok()) {
// Init PluggableDevice Plugin
} else {
return s;
}
loaded_libs[library_filename] = library;
}
*result = library.handle;
return Status::OK();
}
}
} // namespace tensorflow

View File

@ -1244,12 +1244,12 @@ class Context(object):
def invoking_op_callbacks(self, value):
self._thread_local_data.invoking_op_callbacks = value
def _initialize_physical_devices(self):
def _initialize_physical_devices(self, reinitialize=False):
"""Get local devices visible to the system."""
# We lazy initialize self._physical_devices since we do not want to do this
# the constructor since the backend may not be initialized yet.
with self._device_lock:
if self._physical_devices is not None:
if reinitialize is False and self._physical_devices is not None:
return
devs = pywrap_tfe.TF_ListPhysicalDevices()
@ -1270,26 +1270,9 @@ class Context(object):
def reinitialize_physical_devices(self):
"""Get local devices visible to the system."""
# We lazy initialize self._physical_devices since we do not want to do this
# the constructor since the backend may not be initialized yet.
with self._device_lock:
devs = pywrap_tfe.TF_ListPhysicalDevices()
self._physical_devices = [
PhysicalDevice(name=d.decode(),
device_type=d.decode().split(":")[1]) for d in devs]
self._physical_device_to_index = {
p: i for i, p in enumerate(self._physical_devices)
}
self._visible_device_list = list(self._physical_devices)
self._memory_growth_map = {
d: None for d in self._physical_devices if d.device_type == "GPU"
}
# Import device settings that may have been passed into the constructor
self._import_config()
# Reinitialize the physical device list after registering
# the pluggable device.
self._initialize_physical_devices(True)
def list_physical_devices(self, device_type=None):
"""List local devices visible to the system.