From 4d9db8b8091f6e4565b5e9b3a86dc98791644558 Mon Sep 17 00:00:00 2001 From: Zhoulong Jiang Date: Fri, 23 Oct 2020 15:23:53 +0000 Subject: [PATCH] address comments --- tensorflow/c/c_api_experimental.cc | 35 +++++++++++++++---- tensorflow/c/c_api_experimental.h | 2 +- .../c/experimental/stream_executor/BUILD | 8 ----- .../c/experimental/stream_executor/test/BUILD | 16 +++------ tensorflow/c/kernels.h | 7 ++-- tensorflow/c/test_pluggable_device.cc | 24 ------------- tensorflow/core/framework/load_library.cc | 31 ---------------- tensorflow/python/eager/context.py | 29 ++++----------- 8 files changed, 43 insertions(+), 109 deletions(-) delete mode 100644 tensorflow/c/test_pluggable_device.cc diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 6569ee10074..2f3dba8a9ce 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -37,7 +37,9 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/strcat.h" @@ -747,6 +749,14 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints( opts->opts.validate_colocation_constraints = enable; } +// Load a Pluggable Device library +// On sucess, returns the handle to library in result and return OK from the +// function. Otherwise return nullptr in result and error Status from the +// function. +// +// If `library_filename` has already been loaded, we return a cached handle. +// Device and Kernels/Ops are registered as globals when a library is loaded +// for the first time. TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename, TF_Status* status) { #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) @@ -755,17 +765,28 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename, return nullptr; #else TF_Library* lib_handle = new TF_Library; - status->status = tensorflow::LoadPluggableDeviceLibrary( - library_filename, &lib_handle->lib_handle); - if (!status->status.ok()) { - delete lib_handle; - return nullptr; + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static std::unordered_map loaded_libs; + tensorflow::Env* env = tensorflow::Env::Default(); + { + tensorflow::mutex_lock lock(mu); + if (loaded_libs.find(library_filename) != loaded_libs.end()) { + lib_handle->lib_handle = loaded_libs[library_filename]; + } else { + status->status = + env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle); + if (status->status.ok()) { + // Init PluggableDevice Plugin + } else { + delete lib_handle; + return nullptr; + } + } + return lib_handle; } - return lib_handle; #endif } void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) { - if (lib_handle == nullptr) return; delete lib_handle; } diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 6199d2a3822..317e2bf4a91 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -308,7 +308,7 @@ TF_ImportGraphDefOptionsSetValidateColocationConstraints( // device and related kernels present in that library. // // Pass "library_filename" to a platform-specific mechanism for dynamically -// loading a library. The reles for determining the exact location of the +// loading a library. The rules for determining the exact location of the // library are platform-specific and are not documented here // // On success, place OK in status and return the newly created library handle. diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 0334e0896f6..1a421dca51e 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -60,11 +60,3 @@ tf_cc_test( "//tensorflow/stream_executor:stream_executor_pimpl", ], ) - -exports_files( - srcs = [ - "stream_executor.h", - "stream_executor_internal.h", - ], - visibility = ["//visibility:public"], -) diff --git a/tensorflow/c/experimental/stream_executor/test/BUILD b/tensorflow/c/experimental/stream_executor/test/BUILD index 45a5f5613c3..e8c1ae41eee 100644 --- a/tensorflow/c/experimental/stream_executor/test/BUILD +++ b/tensorflow/c/experimental/stream_executor/test/BUILD @@ -2,20 +2,12 @@ # test for stream_executor load( "//tensorflow:tensorflow.bzl", - "tf_custom_op_library", + "tf_cc_shared_object", ) -cc_library( - name = "test_pluggable_device_impl", - srcs = ["test_pluggable_device.cc"], - hdrs = [ - "//tensorflow/c:headers", - "//tensorflow/c/experimental/stream_executor:stream_executor.h", - ], -) - -tf_custom_op_library( +tf_cc_shared_object( name = "test_pluggable_device.so", + srcs = ["test_pluggable_device.cc"], visibility = ["//tensorflow/c:__subpackages__"], - deps = [":test_pluggable_device_impl"], + deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"], ) diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 6905907e808..de3369c81cc 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -134,12 +134,13 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); // -------------------------------------------------------------------------- // OpKernelContext routines -// TF_GetStream returns the SP_Stream available in ctx +// TF_GetStream returns the SP_Stream available in ctx. // This function returns a stream only for devices registered using the // StreamExecutor C API // (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return -// nullptr in all other cases. Experimental: this function doesn't have -// compatibility guarantees and subject to change at any time." +// nullptr in all other cases. +// Experimental: this function doesn't have compatibility guarantees and subject +// to change at any time." TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx); // TF_NumInputs returns the number of inputs available in ctx. diff --git a/tensorflow/c/test_pluggable_device.cc b/tensorflow/c/test_pluggable_device.cc deleted file mode 100644 index 8151aba6b65..00000000000 --- a/tensorflow/c/test_pluggable_device.cc +++ /dev/null @@ -1,24 +0,0 @@ - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/c/experimental/stream_executor/stream_executor.h" - -void SE_InitPlugin(SE_PlatformRegistrationParams* const params, - TF_Status* const status) { - params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE; - params->platform->name = "GPU"; - params->platform->type = "XGPU"; -} diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index 60fff4b5aa9..34cd4b3386b 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -101,35 +101,4 @@ Status LoadDynamicLibrary(const char* library_filename, void** result, return Status::OK(); } -// Load a Pluggable Device library -// On sucess, returns the handle to library in result and return OK from the -// function. Otherwise return nullptr in result and error Status from the -// function. -// -// If `library_filename` has already been loaded, we return a cached handle. -// Device and Kernels/Ops are registered as globals when a library is laoded -// for the first time. -Status LoadPluggableDeviceLibrary(const char* library_filename, void** result) { - static mutex mu(LINKER_INITIALIZED); - static std::unordered_map loaded_libs; - Env* env = Env::Default(); - Library library; - { - mutex_lock lock(mu); - if (loaded_libs.find(library_filename) != loaded_libs.end()) { - library = loaded_libs[library_filename]; - } else { - Status s = env->LoadDynamicLibrary(library_filename, &library.handle); - if (s.ok()) { - // Init PluggableDevice Plugin - } else { - return s; - } - - loaded_libs[library_filename] = library; - } - *result = library.handle; - return Status::OK(); - } -} } // namespace tensorflow diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index d3113598949..c412b338b58 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -1244,12 +1244,12 @@ class Context(object): def invoking_op_callbacks(self, value): self._thread_local_data.invoking_op_callbacks = value - def _initialize_physical_devices(self): + def _initialize_physical_devices(self, reinitialize=False): """Get local devices visible to the system.""" # We lazy initialize self._physical_devices since we do not want to do this # the constructor since the backend may not be initialized yet. with self._device_lock: - if self._physical_devices is not None: + if reinitialize is False and self._physical_devices is not None: return devs = pywrap_tfe.TF_ListPhysicalDevices() @@ -1267,29 +1267,12 @@ class Context(object): # Import device settings that may have been passed into the constructor self._import_config() - + def reinitialize_physical_devices(self): """Get local devices visible to the system.""" - # We lazy initialize self._physical_devices since we do not want to do this - # the constructor since the backend may not be initialized yet. - with self._device_lock: - devs = pywrap_tfe.TF_ListPhysicalDevices() - self._physical_devices = [ - PhysicalDevice(name=d.decode(), - device_type=d.decode().split(":")[1]) for d in devs] - self._physical_device_to_index = { - p: i for i, p in enumerate(self._physical_devices) - } - - self._visible_device_list = list(self._physical_devices) - self._memory_growth_map = { - d: None for d in self._physical_devices if d.device_type == "GPU" - } - - # Import device settings that may have been passed into the constructor - self._import_config() - - + # Reinitialize the physical device list after registering + # the pluggable device. + self._initialize_physical_devices(True) def list_physical_devices(self, device_type=None): """List local devices visible to the system.