diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 99a278a14a4..4bfc79d9939 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -145,6 +145,8 @@ if _running_from_pip_package(): _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') if _os.path.exists(_plugin_dir): _ll.load_library(_plugin_dir) + # Load Pluggable Device Library + _ll.load_pluggable_device_library(_plugin_dir) # Add module aliases if hasattr(_current_module, 'keras'): diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index ae82f7b4792..e69287afd46 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -155,6 +155,8 @@ if _running_from_pip_package(): _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') if _os.path.exists(_plugin_dir): _ll.load_library(_plugin_dir) + # Load Pluggable Device Library + _ll.load_pluggable_device_library(_plugin_dir) # Delete modules that should be hidden from dir(). # Don't fail if these modules are not available. diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 16f6b860308..43119d2b513 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -492,6 +492,8 @@ tf_cuda_library( ], hdrs = [ "kernels.h", + "//tensorflow/c/experimental/stream_executor:stream_executor.h", + "//tensorflow/c/experimental/stream_executor:stream_executor_internal.h", ], copts = tf_copts(), visibility = ["//visibility:public"], @@ -499,6 +501,7 @@ tf_cuda_library( ":tf_status", ":tf_status_helper", ":tf_tensor_internal", + "//tensorflow/c/experimental/stream_executor:stream_executor", ] + select({ "//tensorflow:android": [ ":c_api_internal", @@ -578,6 +581,7 @@ tf_cuda_cc_test( srcs = ["c_api_test.cc"], data = [ ":test_op1.so", + ":test_pluggable_device.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], linkopts = select({ @@ -688,6 +692,19 @@ tf_custom_op_library( srcs = ["test_op1.cc"], ) +cc_library( + name = "test_pluggable_device_impl", + srcs = ["test_pluggable_device.cc"], + hdrs = ["c_api_macros.h", + "tf_status.h", + "//tensorflow/c/experimental/stream_executor:stream_executor.h",], +) + +tf_custom_op_library( + name = "test_pluggable_device.so", + deps = [":test_pluggable_device_impl"], +) + tf_kernel_library( name = "test_op_kernel", srcs = ["test_op.cc"], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 2e1759ecea0..f88cbe63c3b 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -308,6 +308,9 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, Status LoadDynamicLibrary(const char* library_filename, void** result, const void** buf, size_t* len); +// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file) +Status LoadPluggableDeviceLibrary(const char* library_filename, void** result); + // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend().
@@ -581,6 +584,23 @@ TF_Buffer* TF_GetAllOpList() { return ret; } +TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename, + TF_Status* status) { + TF_Library* lib_handle = new TF_Library; + status->status = tensorflow::LoadPluggableDeviceLibrary( + library_filename, &lib_handle->lib_handle); + if (!status->status.ok()) { + delete lib_handle; + return nullptr; + } + return lib_handle; +} + +void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) { + if (lib_handle == nullptr) return; + delete lib_handle; +} + // -------------------------------------------------------------------------- // ListDevices & SessionListDevices API diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 0b4d9993e4d..9113f70ce0c 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1449,7 +1449,6 @@ typedef struct TF_Library TF_Library; // On failure, place an error status in status and return NULL. TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status); - // Get the OpList of OpDefs defined in the library pointed by lib_handle. // // Returns a TF_Buffer. The memory pointed to by the result is owned by @@ -1461,6 +1460,24 @@ TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); // Does NOT unload the library. TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); +// Load the library specified by library_filename and register the pluggable +// device and related kernels present in that library. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, place OK in status and return the newly created library handle. +// The caller owns the library handle. +// +// On failure, place an error status in status and return NULL. +TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary( + const char* library_filename, TF_Status* status); + +// Frees the memory associated with the library handle. +// Does NOT unload the library. +TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle( + TF_Library* lib_handle); // Get the OpList of all OpDefs defined in this address space. // Returns a TF_Buffer, ownership of which is transferred to the caller // (and can be freed using TF_DeleteBuffer). diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index bbbbb8f7d56..d102015b643 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -228,6 +228,21 @@ TEST(CAPI, LibraryLoadFunctions) { } } +TEST(CAPI, LibraryPluggableDeviceLoadFunctions) { +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + // Load the library.
+ TF_Status* status = TF_NewStatus(); + string lib_path = tensorflow::GetDataDependencyFilepath( + tensorflow::io::JoinPath("tensorflow", "c", "test_pluggable_device.so")); + TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + TF_DeletePluggableDeviceLibraryHandle(lib); +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) +} + void TestEncodeDecode(int line, const std::vector& data) { const tensorflow::int64 n = data.size(); Status status; @@ -1467,6 +1482,7 @@ TEST(CAPI, DeletingNullPointerIsSafe) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteDeviceList(nullptr); TF_DeleteLibraryHandle(nullptr); + TF_DeletePluggableDeviceLibraryHandle(nullptr); TF_DeleteApiDefMap(nullptr); TF_DeleteStatus(status); diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 95bb12e8e50..f0dfacf0f85 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -59,3 +59,12 @@ tf_cc_test( "//tensorflow/stream_executor:stream_executor_pimpl", ], ) + +exports_files( + srcs = [ + "stream_executor.h", + "stream_executor_internal.h", + ], + visibility = ["//visibility:public"], +) + diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 09442a4f7b7..093499c298e 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -24,7 +24,6 @@ limitations under the License. #include #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" -#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" @@ -158,41 +157,6 @@ port::Status ValidateSEPlatformRegistrationParams( } #undef VALIDATE_MEMBER -struct TFStatusDeleter { - void operator()(TF_Status* s) const { TF_DeleteStatus(s); } -}; -using OwnedTFStatus = std::unique_ptr; - -class CStream : public internal::StreamInterface { - public: - CStream(SP_Device* device, SP_StreamExecutor* stream_executor) - : device_(device), - stream_executor_(stream_executor), - stream_handle_(nullptr) {} - ~CStream() override { Destroy(); } - - port::Status Create() { - OwnedTFStatus c_status(TF_NewStatus()); - stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); - port::Status s = StatusFromTF_Status(c_status.get()); - return s; - } - - void Destroy() { - if (stream_handle_ != nullptr) { - stream_executor_->destroy_stream(device_, stream_handle_); - stream_handle_ = nullptr; - } - } - - SP_Stream Handle() { return stream_handle_; } - - private: - SP_Device* device_; - SP_StreamExecutor* stream_executor_; - SP_Stream stream_handle_; -}; - // Converts SE_EventStatus to Event::Status. 
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) { switch (s) { @@ -207,82 +171,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) { } } -class CEvent : public internal::EventInterface { - public: - CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) - : device_(device), - stream_executor_(stream_executor), - event_handle_(nullptr) {} - ~CEvent() override { Destroy(); } - - port::Status Create() { - OwnedTFStatus c_status(TF_NewStatus()); - stream_executor_->create_event(device_, &event_handle_, c_status.get()); - return StatusFromTF_Status(c_status.get()); - } - - port::Status Record(SP_Stream stream_handle) { - OwnedTFStatus c_status(TF_NewStatus()); - stream_executor_->record_event(device_, stream_handle, event_handle_, - c_status.get()); - return StatusFromTF_Status(c_status.get()); - } - - void Destroy() { - if (event_handle_ != nullptr) { - stream_executor_->destroy_event(device_, event_handle_); - event_handle_ = nullptr; - } - } - - SP_Event Handle() { return event_handle_; } - - private: - SP_Device* device_; - SP_StreamExecutor* stream_executor_; - SP_Event event_handle_; -}; - -class CTimer : public internal::TimerInterface { - public: - CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, - SP_TimerFns* timer_fns) - : device_(device), - stream_executor_(stream_executor), - timer_handle_(nullptr), - timer_fns_(timer_fns) {} - ~CTimer() override { Destroy(); } - - port::Status Create() { - OwnedTFStatus c_status(TF_NewStatus()); - stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); - return StatusFromTF_Status(c_status.get()); - } - - void Destroy() { - if (timer_handle_ != nullptr) { - stream_executor_->destroy_timer(device_, timer_handle_); - timer_handle_ = nullptr; - } - } - - SP_Timer Handle() { return timer_handle_; } - - uint64 Microseconds() const override { - return timer_fns_->nanoseconds(timer_handle_) / 1000; - } - - uint64 Nanoseconds() const override { - return timer_fns_->nanoseconds(timer_handle_); - } - - private: - SP_Device* device_; - SP_StreamExecutor* stream_executor_; - SP_Timer timer_handle_; - SP_TimerFns* timer_fns_; -}; - // Converts DeviceMemoryBase to a C struct. SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) { SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 52ae4ba77e0..bebf6069829 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -19,10 +19,12 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ #include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/platform.h" +using tensorflow::StatusFromTF_Status; namespace stream_executor { // Plugin initialization function that a device plugin @@ -37,6 +39,13 @@ port::Status InitStreamExecutorPlugin(void* dso_handle); // testing). 
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn); +namespace { +struct TFStatusDeleter { + void operator()(TF_Status* s) const { TF_DeleteStatus(s); } +}; +using OwnedTFStatus = std::unique_ptr; +} // namespace + class CPlatform : public Platform { public: explicit CPlatform(SP_Platform platform, @@ -83,5 +92,111 @@ class CPlatform : public Platform { stream_executor::ExecutorCache executor_cache_; }; +class CStream : public internal::StreamInterface { + public: + CStream(SP_Device* device, SP_StreamExecutor* stream_executor) + : device_(device), + stream_executor_(stream_executor), + stream_handle_(nullptr) {} + ~CStream() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); + port::Status s = StatusFromTF_Status(c_status.get()); + return s; + } + + void Destroy() { + if (stream_handle_ != nullptr) { + stream_executor_->destroy_stream(device_, stream_handle_); + stream_handle_ = nullptr; + } + } + + SP_Stream Handle() { return stream_handle_; } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Stream stream_handle_; +}; + +class CEvent : public internal::EventInterface { + public: + CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) + : device_(device), + stream_executor_(stream_executor), + event_handle_(nullptr) {} + ~CEvent() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_event(device_, &event_handle_, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + port::Status Record(SP_Stream stream_handle) { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->record_event(device_, stream_handle, event_handle_, + c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + void Destroy() { + if (event_handle_ != nullptr) { + stream_executor_->destroy_event(device_, event_handle_); + event_handle_ = nullptr; + } + } + + SP_Event Handle() { return event_handle_; } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Event event_handle_; +}; + +class CTimer : public internal::TimerInterface { + public: + CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, + SP_TimerFns* timer_fns) + : device_(device), + stream_executor_(stream_executor), + timer_handle_(nullptr), + timer_fns_(timer_fns) {} + ~CTimer() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + void Destroy() { + if (timer_handle_ != nullptr) { + stream_executor_->destroy_timer(device_, timer_handle_); + timer_handle_ = nullptr; + } + } + + SP_Timer Handle() { return timer_handle_; } + + uint64 Microseconds() const override { + return timer_fns_->nanoseconds(timer_handle_) / 1000; + } + + uint64 Nanoseconds() const override { + return timer_fns_->nanoseconds(timer_handle_); + } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Timer timer_handle_; + SP_TimerFns* timer_fns_; +}; + } // namespace stream_executor #endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index ed501b5b101..878b920b885 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -25,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/stream.h" // This file forms the basis of a stable ABI for third-party kernel // implementations. It is crucial that changes to this file are made cautiously @@ -168,6 +170,15 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, TF_SetStatus(status, TF_OK, ""); } +// This function is only for pluggable devices; +// it will return nullptr in all other cases. +SP_Stream TF_GetStream(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + auto c_stream = static_cast<stream_executor::CStream*>( + cc_ctx->op_device_context()->stream()->implementation()); + return c_stream->Handle(); +} + int TF_NumInputs(TF_OpKernelContext* ctx) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); return cc_ctx->num_inputs(); diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 489aa5399a5..f5214c0e893 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" @@ -65,6 +66,10 @@ typedef struct TF_KernelBuilder TF_KernelBuilder; typedef struct TF_OpKernelConstruction TF_OpKernelConstruction; typedef struct TF_OpKernelContext TF_OpKernelContext; +// TF_InitKernel does op/kernel registration. +// A plugin needs to implement this function to register all of its kernels. +void TF_InitKernel(); + // Allocates a new kernel builder and returns a pointer to it. // // If non-null, TensorFlow will call create_func when it needs to instantiate @@ -128,6 +133,11 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); // -------------------------------------------------------------------------- // OpKernelContext routines +// TF_GetStream returns the SP_Stream available in ctx. +// This function is only for pluggable devices; +// it will return nullptr in all other cases. +TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx); + // TF_NumInputs returns the number of inputs available in ctx. TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); diff --git a/tensorflow/c/test_pluggable_device.cc b/tensorflow/c/test_pluggable_device.cc new file mode 100644 index 00000000000..8151aba6b65 --- /dev/null +++ b/tensorflow/c/test_pluggable_device.cc @@ -0,0 +1,24 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +void SE_InitPlugin(SE_PlatformRegistrationParams* const params, + TF_Status* const status) { + params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE; + params->platform->name = "GPU"; + params->platform->type = "XGPU"; +} diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index 34cd4b3386b..60fff4b5aa9 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -101,4 +101,35 @@ Status LoadDynamicLibrary(const char* library_filename, void** result, return Status::OK(); } +// Load a Pluggable Device library. +// On success, returns the handle to the library in result and returns OK from +// the function. Otherwise returns nullptr in result and an error Status from +// the function. +// +// If `library_filename` has already been loaded, we return a cached handle. +// Devices and Kernels/Ops are registered as globals when a library is loaded +// for the first time. +Status LoadPluggableDeviceLibrary(const char* library_filename, void** result) { + static mutex mu(LINKER_INITIALIZED); + static std::unordered_map<string, Library> loaded_libs; + Env* env = Env::Default(); + Library library; + { + mutex_lock lock(mu); + if (loaded_libs.find(library_filename) != loaded_libs.end()) { + library = loaded_libs[library_filename]; + } else { + Status s = env->LoadDynamicLibrary(library_filename, &library.handle); + if (s.ok()) { + // Init PluggableDevice Plugin + } else { + return s; + } + + loaded_libs[library_filename] = library; + } + *result = library.handle; + return Status::OK(); + } +} } // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 178359e9487..9471bc99dd6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5976,11 +5976,10 @@ py_library( name = "pywrap_tensorflow", srcs = [ "pywrap_tensorflow.py", - ] + if_static( - ["pywrap_dlopen_global_flags.py"], - # Import will fail, indicating no global dlopen flags - otherwise = [], - ), # b/153585257 + # Modular TF needs this file to expose C API symbols + # in pywrap_tensorflow_internal.so + "pywrap_dlopen_global_flags.py", + ], srcs_version = "PY2AND3", deps = [":pywrap_tensorflow_internal"], ) diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 22b4884dd71..6efba380ca0 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -36,9 +36,9 @@ import traceback # go/tf-wildcard-import # pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top -from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow from tensorflow.python.eager import context +from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow # pylint: enable=wildcard-import diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index ac656d322c4..55f0debadcb 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -711,6 +711,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { }, py::return_value_policy::reference); + m.def( + "TF_LoadPluggableDeviceLibrary", + [](const char* library_filename) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + auto output = 
TF_LoadPluggableDeviceLibrary(library_filename, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }, + py::return_value_policy::reference); + m.def("TF_GetOpList", [](TF_Library* lib_handle) { TF_Buffer output_buffer = TF_GetOpList(lib_handle); return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize( @@ -720,6 +732,11 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle, py::call_guard<py::gil_scoped_release>()); + + m.def("TF_DeletePluggableDeviceLibraryHandle", + TF_DeletePluggableDeviceLibraryHandle, + py::call_guard<py::gil_scoped_release>()); + m.def("TF_AddControlInput", TF_AddControlInput); m.def( "TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) { diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index d917f8a4b4e..d3113598949 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -1268,6 +1268,29 @@ class Context(object): # Import device settings that may have been passed into the constructor self._import_config() + def reinitialize_physical_devices(self): + """Reinitialize the list of local devices visible to the system.""" + # We lazily initialize self._physical_devices since we do not want to do this + # in the constructor, as the backend may not be initialized yet. + with self._device_lock: + devs = pywrap_tfe.TF_ListPhysicalDevices() + self._physical_devices = [ + PhysicalDevice(name=d.decode(), + device_type=d.decode().split(":")[1]) for d in devs] + self._physical_device_to_index = { + p: i for i, p in enumerate(self._physical_devices) + } + + self._visible_device_list = list(self._physical_devices) + self._memory_growth_map = { + d: None for d in self._physical_devices if d.device_type == "GPU" + } + + # Import device settings that may have been passed into the constructor + self._import_config() + + + def list_physical_devices(self, device_type=None): """List local devices visible to the system. diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index f37b48e76c2..4e4fb7ced90 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -29,7 +29,7 @@ from tensorflow.python import _pywrap_python_op_gen from tensorflow.python.client import pywrap_tf_session as py_tf from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export - +from tensorflow.python.eager import context @tf_export('load_op_library') def load_op_library(library_filename): @@ -157,3 +157,44 @@ def load_library(library_location): errno.ENOENT, 'The file or folder to load kernel libraries from does not exist.', library_location) + +@tf_export('load_pluggable_device_library') +def load_pluggable_device_library(library_location): + """Loads a TensorFlow PluggableDevice plugin. + "library_location" can be a path to a specific shared object, or a folder. + If it is a folder, all shared objects will be loaded. When the library is + loaded, devices/kernels registered in the library via the StreamExecutor C API + and the Kernel/Op Registration C API are made available in the TensorFlow process. + + Args: + library_location: Path to the plugin or folder of plugins. + Relative or absolute filesystem path to a dynamic library file or folder. + + Returns: + None + + Raises: + OSError: When the file to be loaded is not found. + RuntimeError: When unable to load the library.
+ """ + if file_io.file_exists(library_location): + if file_io.is_directory(library_location): + directory_contents = file_io.list_directory(library_location) + + pluggable_device_libraries = [ + os.path.join(library_location, f) for f in directory_contents + if _is_shared_object(f)] + else: + pluggable_device_libraries = [library_location] + + for lib in pluggable_device_libraries: + py_tf.TF_LoadPluggableDeviceLibrary(lib) + # Reinitialize the physical devices list after plugin registration + context.context().reinitialize_physical_devices() + else: + raise OSError( + errno.ENOENT, + 'The file or folder to load pluggable device libraries from does not exist.', + library_location) + + diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index ba64d009908..2dd0cafe488 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1576,6 +1576,10 @@ tf_module { name: "load_op_library" argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "load_pluggable_device_library" + argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "local_variables" argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 83baba1b1ce..119bd3b04ad 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -772,6 +772,10 @@ tf_module { name: "load_op_library" argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "load_pluggable_device_library" + argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "logical_and" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
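
Usage note: a minimal end-to-end sketch of the new Python entry point added in this change. The plugin path below is hypothetical; any shared object that implements SE_InitPlugin (StreamExecutor C API) and TF_InitKernel (Kernel/Op Registration C API) should behave the same way.

```python
# Minimal usage sketch of tf.load_pluggable_device_library (hypothetical plugin path).
import tensorflow as tf

# Load one plugin .so (a folder of plugins also works). Devices and kernels
# registered by the plugin become visible to the running TensorFlow process,
# and the physical device list is refreshed via reinitialize_physical_devices().
tf.load_pluggable_device_library("./plugins/libdemo_pluggable_device.so")

# The plugged-in device type should now show up alongside CPU/GPU.
print(tf.config.list_physical_devices())
```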