address comments
This commit is contained in:
parent
554a5fbd72
commit
4d9db8b809
@ -37,7 +37,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/platform/blocking_counter.h"
|
#include "tensorflow/core/platform/blocking_counter.h"
|
||||||
#include "tensorflow/core/platform/casts.h"
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/net.h"
|
#include "tensorflow/core/platform/net.h"
|
||||||
#include "tensorflow/core/platform/platform.h"
|
#include "tensorflow/core/platform/platform.h"
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
@ -747,6 +749,14 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
|||||||
opts->opts.validate_colocation_constraints = enable;
|
opts->opts.validate_colocation_constraints = enable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load a Pluggable Device library
|
||||||
|
// On success, returns the handle to library in result and return OK from the
|
||||||
|
// function. Otherwise return nullptr in result and error Status from the
|
||||||
|
// function.
|
||||||
|
//
|
||||||
|
// If `library_filename` has already been loaded, we return a cached handle.
|
||||||
|
// Device and Kernels/Ops are registered as globals when a library is loaded
|
||||||
|
// for the first time.
|
||||||
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
|
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
|
||||||
@ -755,17 +765,28 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
#else
|
#else
|
||||||
TF_Library* lib_handle = new TF_Library;
|
TF_Library* lib_handle = new TF_Library;
|
||||||
status->status = tensorflow::LoadPluggableDeviceLibrary(
|
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
||||||
library_filename, &lib_handle->lib_handle);
|
static std::unordered_map<std::string, void*> loaded_libs;
|
||||||
if (!status->status.ok()) {
|
tensorflow::Env* env = tensorflow::Env::Default();
|
||||||
delete lib_handle;
|
{
|
||||||
return nullptr;
|
tensorflow::mutex_lock lock(mu);
|
||||||
|
if (loaded_libs.find(library_filename) != loaded_libs.end()) {
|
||||||
|
lib_handle->lib_handle = loaded_libs[library_filename];
|
||||||
|
} else {
|
||||||
|
status->status =
|
||||||
|
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
|
||||||
|
if (status->status.ok()) {
|
||||||
|
// Init PluggableDevice Plugin
|
||||||
|
} else {
|
||||||
|
delete lib_handle;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lib_handle;
|
||||||
}
|
}
|
||||||
return lib_handle;
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
|
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
|
||||||
if (lib_handle == nullptr) return;
|
|
||||||
delete lib_handle;
|
delete lib_handle;
|
||||||
}
|
}
|
||||||
|
@ -308,7 +308,7 @@ TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
|||||||
// device and related kernels present in that library.
|
// device and related kernels present in that library.
|
||||||
//
|
//
|
||||||
// Pass "library_filename" to a platform-specific mechanism for dynamically
|
// Pass "library_filename" to a platform-specific mechanism for dynamically
|
||||||
// loading a library. The reles for determining the exact location of the
|
// loading a library. The rules for determining the exact location of the
|
||||||
// library are platform-specific and are not documented here
|
// library are platform-specific and are not documented here
|
||||||
//
|
//
|
||||||
// On success, place OK in status and return the newly created library handle.
|
// On success, place OK in status and return the newly created library handle.
|
||||||
|
@ -60,11 +60,3 @@ tf_cc_test(
|
|||||||
"//tensorflow/stream_executor:stream_executor_pimpl",
|
"//tensorflow/stream_executor:stream_executor_pimpl",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files(
|
|
||||||
srcs = [
|
|
||||||
"stream_executor.h",
|
|
||||||
"stream_executor_internal.h",
|
|
||||||
],
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
|
@ -2,20 +2,12 @@
|
|||||||
# test for stream_executor
|
# test for stream_executor
|
||||||
load(
|
load(
|
||||||
"//tensorflow:tensorflow.bzl",
|
"//tensorflow:tensorflow.bzl",
|
||||||
"tf_custom_op_library",
|
"tf_cc_shared_object",
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
tf_cc_shared_object(
|
||||||
name = "test_pluggable_device_impl",
|
|
||||||
srcs = ["test_pluggable_device.cc"],
|
|
||||||
hdrs = [
|
|
||||||
"//tensorflow/c:headers",
|
|
||||||
"//tensorflow/c/experimental/stream_executor:stream_executor.h",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_custom_op_library(
|
|
||||||
name = "test_pluggable_device.so",
|
name = "test_pluggable_device.so",
|
||||||
|
srcs = ["test_pluggable_device.cc"],
|
||||||
visibility = ["//tensorflow/c:__subpackages__"],
|
visibility = ["//tensorflow/c:__subpackages__"],
|
||||||
deps = [":test_pluggable_device_impl"],
|
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
|
||||||
)
|
)
|
||||||
|
@ -134,12 +134,13 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
|
|||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// OpKernelContext routines
|
// OpKernelContext routines
|
||||||
|
|
||||||
// TF_GetStream returns the SP_Stream available in ctx
|
// TF_GetStream returns the SP_Stream available in ctx.
|
||||||
// This function returns a stream only for devices registered using the
|
// This function returns a stream only for devices registered using the
|
||||||
// StreamExecutor C API
|
// StreamExecutor C API
|
||||||
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
|
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
|
||||||
// nullptr in all other cases. Experimental: this function doesn't have
|
// nullptr in all other cases.
|
||||||
// compatibility guarantees and subject to change at any time."
|
// Experimental: this function doesn't have compatibility guarantees and subject
|
||||||
|
// to change at any time.
|
||||||
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);
|
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx);
|
||||||
|
|
||||||
// TF_NumInputs returns the number of inputs available in ctx.
|
// TF_NumInputs returns the number of inputs available in ctx.
|
||||||
|
@ -1,24 +0,0 @@
|
|||||||
|
|
||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
|
||||||
|
|
||||||
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
|
|
||||||
TF_Status* const status) {
|
|
||||||
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
|
|
||||||
params->platform->name = "GPU";
|
|
||||||
params->platform->type = "XGPU";
|
|
||||||
}
|
|
@ -101,35 +101,4 @@ Status LoadDynamicLibrary(const char* library_filename, void** result,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load a Pluggable Device library
|
|
||||||
// On sucess, returns the handle to library in result and return OK from the
|
|
||||||
// function. Otherwise return nullptr in result and error Status from the
|
|
||||||
// function.
|
|
||||||
//
|
|
||||||
// If `library_filename` has already been loaded, we return a cached handle.
|
|
||||||
// Device and Kernels/Ops are registered as globals when a library is laoded
|
|
||||||
// for the first time.
|
|
||||||
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result) {
|
|
||||||
static mutex mu(LINKER_INITIALIZED);
|
|
||||||
static std::unordered_map<string, Library> loaded_libs;
|
|
||||||
Env* env = Env::Default();
|
|
||||||
Library library;
|
|
||||||
{
|
|
||||||
mutex_lock lock(mu);
|
|
||||||
if (loaded_libs.find(library_filename) != loaded_libs.end()) {
|
|
||||||
library = loaded_libs[library_filename];
|
|
||||||
} else {
|
|
||||||
Status s = env->LoadDynamicLibrary(library_filename, &library.handle);
|
|
||||||
if (s.ok()) {
|
|
||||||
// Init PluggableDevice Plugin
|
|
||||||
} else {
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
|
|
||||||
loaded_libs[library_filename] = library;
|
|
||||||
}
|
|
||||||
*result = library.handle;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -1244,12 +1244,12 @@ class Context(object):
|
|||||||
def invoking_op_callbacks(self, value):
|
def invoking_op_callbacks(self, value):
|
||||||
self._thread_local_data.invoking_op_callbacks = value
|
self._thread_local_data.invoking_op_callbacks = value
|
||||||
|
|
||||||
def _initialize_physical_devices(self):
|
def _initialize_physical_devices(self, reinitialize=False):
|
||||||
"""Get local devices visible to the system."""
|
"""Get local devices visible to the system."""
|
||||||
# We lazy initialize self._physical_devices since we do not want to do this
|
# We lazy initialize self._physical_devices since we do not want to do this
|
||||||
# the constructor since the backend may not be initialized yet.
|
# the constructor since the backend may not be initialized yet.
|
||||||
with self._device_lock:
|
with self._device_lock:
|
||||||
if self._physical_devices is not None:
|
if reinitialize is False and self._physical_devices is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
devs = pywrap_tfe.TF_ListPhysicalDevices()
|
devs = pywrap_tfe.TF_ListPhysicalDevices()
|
||||||
@ -1270,26 +1270,9 @@ class Context(object):
|
|||||||
|
|
||||||
def reinitialize_physical_devices(self):
|
def reinitialize_physical_devices(self):
|
||||||
"""Get local devices visible to the system."""
|
"""Get local devices visible to the system."""
|
||||||
# We lazy initialize self._physical_devices since we do not want to do this
|
# Reinitialize the physical device list after registering
|
||||||
# the constructor since the backend may not be initialized yet.
|
# the pluggable device.
|
||||||
with self._device_lock:
|
self._initialize_physical_devices(True)
|
||||||
devs = pywrap_tfe.TF_ListPhysicalDevices()
|
|
||||||
self._physical_devices = [
|
|
||||||
PhysicalDevice(name=d.decode(),
|
|
||||||
device_type=d.decode().split(":")[1]) for d in devs]
|
|
||||||
self._physical_device_to_index = {
|
|
||||||
p: i for i, p in enumerate(self._physical_devices)
|
|
||||||
}
|
|
||||||
|
|
||||||
self._visible_device_list = list(self._physical_devices)
|
|
||||||
self._memory_growth_map = {
|
|
||||||
d: None for d in self._physical_devices if d.device_type == "GPU"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Import device settings that may have been passed into the constructor
|
|
||||||
self._import_config()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def list_physical_devices(self, device_type=None):
|
def list_physical_devices(self, device_type=None):
|
||||||
"""List local devices visible to the system.
|
"""List local devices visible to the system.
|
||||||
|
Loading…
Reference in New Issue
Block a user