diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 0eae6ad17c0..c1179f8b1c6 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -246,6 +246,7 @@ cc_library( ":graph_info", ":memory_planner", ":minimal_logging", + ":shared_library", ":simple_memory_arena", ":string", ":tflite_with_xnnpack_optional", @@ -635,6 +636,13 @@ cc_test( ], ) +cc_library( + name = "shared_library", + hdrs = ["shared_library.h"], + copts = TFLITE_DEFAULT_COPTS, + linkopts = if_not_windows(["-ldl"]), +) + # Shared lib target for convenience, pulls in the core runtime and builtin ops. # Note: This target is not yet finalized, and the exact set of exported (C/C++) # APIs is subject to change. The output library name is platform dependent: diff --git a/tensorflow/lite/delegates/external/BUILD b/tensorflow/lite/delegates/external/BUILD index ca23f95122f..b1018935365 100644 --- a/tensorflow/lite/delegates/external/BUILD +++ b/tensorflow/lite/delegates/external/BUILD @@ -26,6 +26,7 @@ cc_library( hdrs = ["external_delegate.h"], deps = [ "//tensorflow/lite:minimal_logging", + "//tensorflow/lite:shared_library", "//tensorflow/lite/c:common", ], ) diff --git a/tensorflow/lite/delegates/external/external_delegate.cc b/tensorflow/lite/delegates/external/external_delegate.cc index 0ebfb62421c..02a8a1e3cfd 100644 --- a/tensorflow/lite/delegates/external/external_delegate.cc +++ b/tensorflow/lite/delegates/external/external_delegate.cc @@ -17,42 +17,12 @@ limitations under the License. #include #include -#if defined(_WIN32) -#include -#else -#include -#endif - #include "tensorflow/lite/minimal_logging.h" +#include "tensorflow/lite/shared_library.h" namespace tflite { namespace { -// Library Support construct to handle dynamic library operations -#if defined(_WIN32) -struct LibSupport { - static void* Load(const char* lib) { return LoadLibrary(lib); } - - static void* GetSymbol(void* handle, const char* symbol) { - return (void*)GetProcAddress((HMODULE)handle, symbol); - } - - static int UnLoad(void* handle) { return FreeLibrary((HMODULE)handle); } -}; -#else -struct LibSupport { - static void* Load(const char* lib) { - return dlopen(lib, RTLD_LAZY | RTLD_LOCAL); - } - - static void* GetSymbol(void* handle, const char* symbol) { - return dlsym(handle, symbol); - } - - static int UnLoad(void* handle) { return dlclose(handle); } -}; -#endif - // External delegate library construct struct ExternalLib { using CreateDelegatePtr = std::add_pointer( - LibSupport::GetSymbol(handle, "tflite_plugin_create_delegate")); - destroy = reinterpret_cast( - LibSupport::GetSymbol(handle, "tflite_plugin_destroy_delegate")); + create = + reinterpret_cast(SharedLibrary::GetLibrarySymbol( + handle, "tflite_plugin_create_delegate")); + destroy = + reinterpret_cast(SharedLibrary::GetLibrarySymbol( + handle, "tflite_plugin_destroy_delegate")); return create && destroy; } return false; diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index 4b491d41881..996fc7e6b82 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -14,9 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/interpreter_builder.h" -#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32) -#include -#endif #include #include #include @@ -31,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/shared_library.h" #include "tensorflow/lite/tflite_with_xnnpack_optional.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" @@ -117,15 +115,22 @@ const char* kEmptyTensorName = ""; // For flex delegate, see also the strong override in // lite/delegates/flex/delegate.cc. TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { -#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32) +#if !defined(__ANDROID__) // If _pywrap_tensorflow_internal.so is available, use // TF_AcquireFlexDelegate() to initialize flex delegate. + const char* filename_pywrap_tensorflow_internal = +#if defined(_WIN32) + "_pywrap_tensorflow_internal.pyd"; +#else + "_pywrap_tensorflow_internal.so"; +#endif void* lib_tf_internal = - dlopen("_pywrap_tensorflow_internal.so", RTLD_NOW | RTLD_LOCAL); + SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal); if (lib_tf_internal) { auto TF_AcquireFlexDelegate = reinterpret_cast( - dlsym(lib_tf_internal, "TF_AcquireFlexDelegate")); + SharedLibrary::GetLibrarySymbol(lib_tf_internal, + "TF_AcquireFlexDelegate")); if (TF_AcquireFlexDelegate) { return TF_AcquireFlexDelegate(); } diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index b3799be7af9..427f54d0e2c 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -28,6 +28,7 @@ cc_library( ":python_error_reporter", ":python_utils", "//tensorflow/lite:framework", + "//tensorflow/lite:shared_library", "//tensorflow/lite:string_util", "//tensorflow/lite:util", "//tensorflow/lite/c:common", diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 2a8c1ffdcd6..7295a46193e 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -14,13 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" -// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead. -#if defined(_WIN32) -#include -#else -#include -#endif // defined(_WIN32) - #include #include @@ -36,6 +29,7 @@ limitations under the License. #include "tensorflow/lite/python/interpreter_wrapper/numpy.h" #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" +#include "tensorflow/lite/shared_library.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/util.h" @@ -154,25 +148,13 @@ bool RegisterCustomOpByName(const char* registerer_name, // Look for the Registerer function by name. RegistererFunctionType registerer = reinterpret_cast( - // We don't have dlsym on Windows, use GetProcAddress instead. -#if defined(_WIN32) - GetProcAddress(nullptr, registerer_name) -#else - dlsym(RTLD_DEFAULT, registerer_name) -#endif // defined(_WIN32) - ); + SharedLibrary::GetSymbol(registerer_name)); // Fail in an informative way if the function was not found. if (registerer == nullptr) { - // We don't have dlerror on Windows, use GetLastError instead. *error_msg = -#if defined(_WIN32) - absl::StrFormat("Looking up symbol '%s' failed with error (0x%x).", - registerer_name, GetLastError()); -#else absl::StrFormat("Looking up symbol '%s' failed with error '%s'.", - registerer_name, dlerror()); -#endif // defined(_WIN32) + registerer_name, SharedLibrary::GetError()); return false; } diff --git a/tensorflow/lite/shared_library.h b/tensorflow/lite/shared_library.h new file mode 100644 index 00000000000..7cf34a03125 --- /dev/null +++ b/tensorflow/lite/shared_library.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SHARED_LIBRARY_H_ +#define TENSORFLOW_LITE_SHARED_LIBRARY_H_ + +#if defined(_WIN32) +// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead. +#include +#else +#include +#endif // defined(_WIN32) + +namespace tflite { + +// SharedLibrary provides a uniform set of APIs across different platforms to +// handle dynamic library operations +class SharedLibrary { + public: +#if defined(_WIN32) + static inline void* LoadLibrary(const char* lib) { + return ::LoadLibrary(lib); + } + static inline void* GetLibrarySymbol(void* handle, const char* symbol) { + return static_cast( + GetProcAddress(static_cast(handle), symbol)); + } + static inline void* GetSymbol(const char* symbol) { + return static_cast(GetProcAddress(nullptr, symbol)); + } + static inline int UnLoadLibrary(void* handle) { + return FreeLibrary(static_cast(handle)); + } + static inline const char* GetError() { return "Unknown"; } +#else + static inline void* LoadLibrary(const char* lib) { + return dlopen(lib, RTLD_LAZY | RTLD_LOCAL); + } + static inline void* GetLibrarySymbol(void* handle, const char* symbol) { + return dlsym(handle, symbol); + } + static inline void* GetSymbol(const char* symbol) { + return dlsym(RTLD_DEFAULT, symbol); + } + static inline int UnLoadLibrary(void* handle) { return dlclose(handle); } + static inline const char* GetError() { return dlerror(); } +#endif // defined(_WIN32) +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_SHARED_LIBRARY_H_