tflite: Add SharedLibrary class

Introduce SharedLibrary class to handle shared library regardless of OS type.
Applied this class to flex delegate, external delegate and interpreter_wrapper.

PiperOrigin-RevId: 320307902
Change-Id: Iab7b56599eb20e56bd57d59d4c52a2ff7d571604
This commit is contained in:
Terry Heo 2020-07-08 18:36:44 -07:00 committed by TensorFlower Gardener
parent f5e22e5e0e
commit af54c4e33b
7 changed files with 95 additions and 63 deletions

View File

@ -246,6 +246,7 @@ cc_library(
":graph_info",
":memory_planner",
":minimal_logging",
":shared_library",
":simple_memory_arena",
":string",
":tflite_with_xnnpack_optional",
@ -635,6 +636,13 @@ cc_test(
],
)
cc_library(
name = "shared_library",
hdrs = ["shared_library.h"],
copts = TFLITE_DEFAULT_COPTS,
linkopts = if_not_windows(["-ldl"]),
)
# Shared lib target for convenience, pulls in the core runtime and builtin ops.
# Note: This target is not yet finalized, and the exact set of exported (C/C++)
# APIs is subject to change. The output library name is platform dependent:

View File

@ -26,6 +26,7 @@ cc_library(
hdrs = ["external_delegate.h"],
deps = [
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite:shared_library",
"//tensorflow/lite/c:common",
],
)

View File

@ -17,42 +17,12 @@ limitations under the License.
#include <string>
#include <vector>
#if defined(_WIN32)
#include <Windows.h>
#else
#include <dlfcn.h>
#endif
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/shared_library.h"
namespace tflite {
namespace {
// Library Support construct to handle dynamic library operations
#if defined(_WIN32)
struct LibSupport {
static void* Load(const char* lib) { return LoadLibrary(lib); }
static void* GetSymbol(void* handle, const char* symbol) {
return (void*)GetProcAddress((HMODULE)handle, symbol);
}
static int UnLoad(void* handle) { return FreeLibrary((HMODULE)handle); }
};
#else
struct LibSupport {
static void* Load(const char* lib) {
return dlopen(lib, RTLD_LAZY | RTLD_LOCAL);
}
static void* GetSymbol(void* handle, const char* symbol) {
return dlsym(handle, symbol);
}
static int UnLoad(void* handle) { return dlclose(handle); }
};
#endif
// External delegate library construct
struct ExternalLib {
using CreateDelegatePtr = std::add_pointer<TfLiteDelegate*(
@ -62,15 +32,17 @@ struct ExternalLib {
// Open a given delegate library and load the create/destroy symbols
bool load(const std::string library) {
void* handle = LibSupport::Load(library.c_str());
void* handle = SharedLibrary::LoadLibrary(library.c_str());
if (handle == nullptr) {
TFLITE_LOG(TFLITE_LOG_INFO, "Unable to load external delegate from : %s",
library.c_str());
} else {
create = reinterpret_cast<decltype(create)>(
LibSupport::GetSymbol(handle, "tflite_plugin_create_delegate"));
destroy = reinterpret_cast<decltype(destroy)>(
LibSupport::GetSymbol(handle, "tflite_plugin_destroy_delegate"));
create =
reinterpret_cast<decltype(create)>(SharedLibrary::GetLibrarySymbol(
handle, "tflite_plugin_create_delegate"));
destroy =
reinterpret_cast<decltype(destroy)>(SharedLibrary::GetLibrarySymbol(
handle, "tflite_plugin_destroy_delegate"));
return create && destroy;
}
return false;

View File

@ -14,9 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter_builder.h"
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
#include <dlfcn.h>
#endif
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
@ -31,6 +28,7 @@ limitations under the License.
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/tflite_with_xnnpack_optional.h"
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h"
@ -117,15 +115,22 @@ const char* kEmptyTensorName = "";
// For flex delegate, see also the strong override in
// lite/delegates/flex/delegate.cc.
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
#if !defined(__ANDROID__)
// If _pywrap_tensorflow_internal.so is available, use
// TF_AcquireFlexDelegate() to initialize flex delegate.
const char* filename_pywrap_tensorflow_internal =
#if defined(_WIN32)
"_pywrap_tensorflow_internal.pyd";
#else
"_pywrap_tensorflow_internal.so";
#endif
void* lib_tf_internal =
dlopen("_pywrap_tensorflow_internal.so", RTLD_NOW | RTLD_LOCAL);
SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
if (lib_tf_internal) {
auto TF_AcquireFlexDelegate =
reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
dlsym(lib_tf_internal, "TF_AcquireFlexDelegate"));
SharedLibrary::GetLibrarySymbol(lib_tf_internal,
"TF_AcquireFlexDelegate"));
if (TF_AcquireFlexDelegate) {
return TF_AcquireFlexDelegate();
}

View File

@ -28,6 +28,7 @@ cc_library(
":python_error_reporter",
":python_utils",
"//tensorflow/lite:framework",
"//tensorflow/lite:shared_library",
"//tensorflow/lite:string_util",
"//tensorflow/lite:util",
"//tensorflow/lite/c:common",

View File

@ -14,13 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead.
#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif // defined(_WIN32)
#include <stdarg.h>
#include <sstream>
@ -36,6 +29,7 @@ limitations under the License.
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/util.h"
@ -154,25 +148,13 @@ bool RegisterCustomOpByName(const char* registerer_name,
// Look for the Registerer function by name.
RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
// We don't have dlsym on Windows, use GetProcAddress instead.
#if defined(_WIN32)
GetProcAddress(nullptr, registerer_name)
#else
dlsym(RTLD_DEFAULT, registerer_name)
#endif // defined(_WIN32)
);
SharedLibrary::GetSymbol(registerer_name));
// Fail in an informative way if the function was not found.
if (registerer == nullptr) {
// We don't have dlerror on Windows, use GetLastError instead.
*error_msg =
#if defined(_WIN32)
absl::StrFormat("Looking up symbol '%s' failed with error (0x%x).",
registerer_name, GetLastError());
#else
absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
registerer_name, dlerror());
#endif // defined(_WIN32)
registerer_name, SharedLibrary::GetError());
return false;
}

View File

@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SHARED_LIBRARY_H_
#define TENSORFLOW_LITE_SHARED_LIBRARY_H_
#if defined(_WIN32)
// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead.
#include <windows.h>
#else
#include <dlfcn.h>
#endif // defined(_WIN32)
namespace tflite {
// SharedLibrary provides a uniform set of APIs across different platforms to
// handle dynamic library operations
class SharedLibrary {
public:
#if defined(_WIN32)
static inline void* LoadLibrary(const char* lib) {
return ::LoadLibrary(lib);
}
static inline void* GetLibrarySymbol(void* handle, const char* symbol) {
return static_cast<void*>(
GetProcAddress(static_cast<HMODULE>(handle), symbol));
}
static inline void* GetSymbol(const char* symbol) {
return static_cast<void*>(GetProcAddress(nullptr, symbol));
}
static inline int UnLoadLibrary(void* handle) {
return FreeLibrary(static_cast<HMODULE>(handle));
}
static inline const char* GetError() { return "Unknown"; }
#else
static inline void* LoadLibrary(const char* lib) {
return dlopen(lib, RTLD_LAZY | RTLD_LOCAL);
}
static inline void* GetLibrarySymbol(void* handle, const char* symbol) {
return dlsym(handle, symbol);
}
static inline void* GetSymbol(const char* symbol) {
return dlsym(RTLD_DEFAULT, symbol);
}
static inline int UnLoadLibrary(void* handle) { return dlclose(handle); }
static inline const char* GetError() { return dlerror(); }
#endif // defined(_WIN32)
};
} // namespace tflite
#endif // TENSORFLOW_LITE_SHARED_LIBRARY_H_