address comments
This commit is contained in:
parent
554a5fbd72
commit
4d9db8b809

@@ -37,7 +37,9 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"

@@ -747,6 +749,14 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
  opts->opts.validate_colocation_constraints = enable;
}

// Load a Pluggable Device library.
// On success, returns the handle to the library in result and returns OK from
// the function. Otherwise returns nullptr in result and an error Status from
// the function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
@@ -755,17 +765,28 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  status->status = tensorflow::LoadPluggableDeviceLibrary(
      library_filename, &lib_handle->lib_handle);
  if (!status->status.ok()) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*> loaded_libs;
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    if (loaded_libs.find(library_filename) != loaded_libs.end()) {
      lib_handle->lib_handle = loaded_libs[library_filename];
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (status->status.ok()) {
        // Init PluggableDevice Plugin
      } else {
        delete lib_handle;
        return nullptr;
      }
    }
    return lib_handle;
  }
#endif
}

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle == nullptr) return;
  delete lib_handle;
}
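
As a rough illustration of how a client could exercise the two entry points above through the C API (a minimal sketch; the header locations and the plugin filename "test_pluggable_device.so" are assumptions, not part of this change):

#include <cstdio>

#include "tensorflow/c/c_api.h"               // TF_Status / TF_Library helpers
#include "tensorflow/c/c_api_experimental.h"  // assumed home of the new declarations

int main() {
  TF_Status* status = TF_NewStatus();
  // Loading registers the pluggable device and its kernels/ops as globals;
  // a repeated call with the same filename returns a cached handle.
  TF_Library* lib =
      TF_LoadPluggableDeviceLibrary("test_pluggable_device.so", status);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "Load failed: %s\n", TF_Message(status));
  } else {
    // Only the TF_Library wrapper is freed here; the loaded library and the
    // devices it registered stay alive.
    TF_DeletePluggableDeviceLibraryHandle(lib);
  }
  TF_DeleteStatus(status);
  return 0;
}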

@@ -308,7 +308,7 @@ TF_ImportGraphDefOptionsSetValidateColocationConstraints(
// device and related kernels present in that library.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The reles for determining the exact location of the
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here
//
// On success, place OK in status and return the newly created library handle.
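
The declarations this comment block documents presumably mirror the definitions in c_api.cc above; a sketch (the TF_CAPI_EXPORT attribute and exact header placement are assumptions):

// Loads the pluggable device library and registers its device and kernels.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
    const char* library_filename, TF_Status* status);

// Frees the TF_Library wrapper returned by TF_LoadPluggableDeviceLibrary.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
    TF_Library* lib_handle);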

@@ -60,11 +60,3 @@ tf_cc_test(
        "//tensorflow/stream_executor:stream_executor_pimpl",
    ],
)

exports_files(
    srcs = [
        "stream_executor.h",
        "stream_executor_internal.h",
    ],
    visibility = ["//visibility:public"],
)

@@ -2,20 +2,12 @@
# test for stream_executor
load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_library",
    "tf_cc_shared_object",
)

cc_library(
    name = "test_pluggable_device_impl",
    srcs = ["test_pluggable_device.cc"],
    hdrs = [
        "//tensorflow/c:headers",
        "//tensorflow/c/experimental/stream_executor:stream_executor.h",
    ],
)

tf_custom_op_library(
tf_cc_shared_object(
    name = "test_pluggable_device.so",
    srcs = ["test_pluggable_device.cc"],
    visibility = ["//tensorflow/c:__subpackages__"],
    deps = [":test_pluggable_device_impl"],
    deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
)

@@ -134,12 +134,13 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
// --------------------------------------------------------------------------
// OpKernelContext routines

// TF_GetStream returns the SP_Stream available in ctx
// TF_GetStream returns the SP_Stream available in ctx.
// This function returns a stream only for devices registered using the
// StreamExecutor C API
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
// nullptr in all other cases. Experimental: this function doesn't have
// compatibility guarantees and subject to change at any time."
// nullptr in all other cases.
// Experimental: this function doesn't have compatibility guarantees and subject
// to change at any time."
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);

// TF_NumInputs returns the number of inputs available in ctx.
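
To make the comment above concrete, a minimal sketch of where TF_GetStream would be called: inside the compute callback of a kernel registered through this kernels C API. The kernel itself is hypothetical:

#include "tensorflow/c/kernels.h"

// Hypothetical compute callback, matching the signature expected by
// TF_NewKernelBuilder: void (*compute_func)(void*, TF_OpKernelContext*).
static void MyPluggableKernel_Compute(void* kernel, TF_OpKernelContext* ctx) {
  // Non-null only when the kernel runs on a device registered through the
  // StreamExecutor C API; nullptr for every other device type.
  SP_Stream stream = TF_GetStream(ctx);
  if (stream == nullptr) {
    // Fall back to a host/synchronous path or report an error here.
    return;
  }
  // ... enqueue device work on `stream` via the plugin's StreamExecutor hooks ...
}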

@@ -1,24 +0,0 @@

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
                   TF_Status* const status) {
  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  params->platform->name = "GPU";
  params->platform->type = "XGPU";
}

@@ -101,35 +101,4 @@ Status LoadDynamicLibrary(const char* library_filename, void** result,
  return Status::OK();
}

// Load a Pluggable Device library
// On sucess, returns the handle to library in result and return OK from the
// function. Otherwise return nullptr in result and error Status from the
// function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is laoded
// for the first time.
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result) {
  static mutex mu(LINKER_INITIALIZED);
  static std::unordered_map<string, Library> loaded_libs;
  Env* env = Env::Default();
  Library library;
  {
    mutex_lock lock(mu);
    if (loaded_libs.find(library_filename) != loaded_libs.end()) {
      library = loaded_libs[library_filename];
    } else {
      Status s = env->LoadDynamicLibrary(library_filename, &library.handle);
      if (s.ok()) {
        // Init PluggableDevice Plugin
      } else {
        return s;
      }

      loaded_libs[library_filename] = library;
    }
    *result = library.handle;
    return Status::OK();
  }
}
}  // namespace tensorflow

@@ -1244,12 +1244,12 @@ class Context(object):
  def invoking_op_callbacks(self, value):
    self._thread_local_data.invoking_op_callbacks = value

  def _initialize_physical_devices(self):
  def _initialize_physical_devices(self, reinitialize=False):
    """Get local devices visible to the system."""
    # We lazy initialize self._physical_devices since we do not want to do this
    # the constructor since the backend may not be initialized yet.
    with self._device_lock:
      if self._physical_devices is not None:
      if reinitialize is False and self._physical_devices is not None:
        return

      devs = pywrap_tfe.TF_ListPhysicalDevices()

@@ -1270,26 +1270,9 @@ class Context(object):

  def reinitialize_physical_devices(self):
    """Get local devices visible to the system."""
    # We lazy initialize self._physical_devices since we do not want to do this
    # the constructor since the backend may not be initialized yet.
    with self._device_lock:
      devs = pywrap_tfe.TF_ListPhysicalDevices()
      self._physical_devices = [
          PhysicalDevice(name=d.decode(),
                         device_type=d.decode().split(":")[1]) for d in devs]
      self._physical_device_to_index = {
          p: i for i, p in enumerate(self._physical_devices)
      }

      self._visible_device_list = list(self._physical_devices)
      self._memory_growth_map = {
          d: None for d in self._physical_devices if d.device_type == "GPU"
      }

      # Import device settings that may have been passed into the constructor
      self._import_config()


    # Reinitialize the physical device list after registering
    # the pluggable device.
    self._initialize_physical_devices(True)

  def list_physical_devices(self, device_type=None):
    """List local devices visible to the system.