tflite: Add SharedLibrary class
Introduce SharedLibrary class to handle shared library regardless of OS type. Applied this class to flex delegate, external delegate and interpreter_wrapper. PiperOrigin-RevId: 320307902 Change-Id: Iab7b56599eb20e56bd57d59d4c52a2ff7d571604
This commit is contained in:
parent
f5e22e5e0e
commit
af54c4e33b
tensorflow/lite
@ -246,6 +246,7 @@ cc_library(
|
||||
":graph_info",
|
||||
":memory_planner",
|
||||
":minimal_logging",
|
||||
":shared_library",
|
||||
":simple_memory_arena",
|
||||
":string",
|
||||
":tflite_with_xnnpack_optional",
|
||||
@ -635,6 +636,13 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "shared_library",
|
||||
hdrs = ["shared_library.h"],
|
||||
copts = TFLITE_DEFAULT_COPTS,
|
||||
linkopts = if_not_windows(["-ldl"]),
|
||||
)
|
||||
|
||||
# Shared lib target for convenience, pulls in the core runtime and builtin ops.
|
||||
# Note: This target is not yet finalized, and the exact set of exported (C/C++)
|
||||
# APIs is subject to change. The output library name is platform dependent:
|
||||
|
1
tensorflow/lite/delegates/external/BUILD
vendored
1
tensorflow/lite/delegates/external/BUILD
vendored
@ -26,6 +26,7 @@ cc_library(
|
||||
hdrs = ["external_delegate.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite:shared_library",
|
||||
"//tensorflow/lite/c:common",
|
||||
],
|
||||
)
|
||||
|
@ -17,42 +17,12 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <Windows.h>
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/shared_library.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
// Library Support construct to handle dynamic library operations
|
||||
#if defined(_WIN32)
|
||||
struct LibSupport {
|
||||
static void* Load(const char* lib) { return LoadLibrary(lib); }
|
||||
|
||||
static void* GetSymbol(void* handle, const char* symbol) {
|
||||
return (void*)GetProcAddress((HMODULE)handle, symbol);
|
||||
}
|
||||
|
||||
static int UnLoad(void* handle) { return FreeLibrary((HMODULE)handle); }
|
||||
};
|
||||
#else
|
||||
struct LibSupport {
|
||||
static void* Load(const char* lib) {
|
||||
return dlopen(lib, RTLD_LAZY | RTLD_LOCAL);
|
||||
}
|
||||
|
||||
static void* GetSymbol(void* handle, const char* symbol) {
|
||||
return dlsym(handle, symbol);
|
||||
}
|
||||
|
||||
static int UnLoad(void* handle) { return dlclose(handle); }
|
||||
};
|
||||
#endif
|
||||
|
||||
// External delegate library construct
|
||||
struct ExternalLib {
|
||||
using CreateDelegatePtr = std::add_pointer<TfLiteDelegate*(
|
||||
@ -62,15 +32,17 @@ struct ExternalLib {
|
||||
|
||||
// Open a given delegate library and load the create/destroy symbols
|
||||
bool load(const std::string library) {
|
||||
void* handle = LibSupport::Load(library.c_str());
|
||||
void* handle = SharedLibrary::LoadLibrary(library.c_str());
|
||||
if (handle == nullptr) {
|
||||
TFLITE_LOG(TFLITE_LOG_INFO, "Unable to load external delegate from : %s",
|
||||
library.c_str());
|
||||
} else {
|
||||
create = reinterpret_cast<decltype(create)>(
|
||||
LibSupport::GetSymbol(handle, "tflite_plugin_create_delegate"));
|
||||
destroy = reinterpret_cast<decltype(destroy)>(
|
||||
LibSupport::GetSymbol(handle, "tflite_plugin_destroy_delegate"));
|
||||
create =
|
||||
reinterpret_cast<decltype(create)>(SharedLibrary::GetLibrarySymbol(
|
||||
handle, "tflite_plugin_create_delegate"));
|
||||
destroy =
|
||||
reinterpret_cast<decltype(destroy)>(SharedLibrary::GetLibrarySymbol(
|
||||
handle, "tflite_plugin_destroy_delegate"));
|
||||
return create && destroy;
|
||||
}
|
||||
return false;
|
||||
|
@ -14,9 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
|
||||
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
#include <fcntl.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
@ -31,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/shared_library.h"
|
||||
#include "tensorflow/lite/tflite_with_xnnpack_optional.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
@ -117,15 +115,22 @@ const char* kEmptyTensorName = "";
|
||||
// For flex delegate, see also the strong override in
|
||||
// lite/delegates/flex/delegate.cc.
|
||||
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
|
||||
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
|
||||
#if !defined(__ANDROID__)
|
||||
// If _pywrap_tensorflow_internal.so is available, use
|
||||
// TF_AcquireFlexDelegate() to initialize flex delegate.
|
||||
const char* filename_pywrap_tensorflow_internal =
|
||||
#if defined(_WIN32)
|
||||
"_pywrap_tensorflow_internal.pyd";
|
||||
#else
|
||||
"_pywrap_tensorflow_internal.so";
|
||||
#endif
|
||||
void* lib_tf_internal =
|
||||
dlopen("_pywrap_tensorflow_internal.so", RTLD_NOW | RTLD_LOCAL);
|
||||
SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
|
||||
if (lib_tf_internal) {
|
||||
auto TF_AcquireFlexDelegate =
|
||||
reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
|
||||
dlsym(lib_tf_internal, "TF_AcquireFlexDelegate"));
|
||||
SharedLibrary::GetLibrarySymbol(lib_tf_internal,
|
||||
"TF_AcquireFlexDelegate"));
|
||||
if (TF_AcquireFlexDelegate) {
|
||||
return TF_AcquireFlexDelegate();
|
||||
}
|
||||
|
@ -28,6 +28,7 @@ cc_library(
|
||||
":python_error_reporter",
|
||||
":python_utils",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:shared_library",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/c:common",
|
||||
|
@ -14,13 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
|
||||
|
||||
// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead.
|
||||
#if defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include <stdarg.h>
|
||||
|
||||
#include <sstream>
|
||||
@ -36,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
#include "tensorflow/lite/shared_library.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
@ -154,25 +148,13 @@ bool RegisterCustomOpByName(const char* registerer_name,
|
||||
|
||||
// Look for the Registerer function by name.
|
||||
RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
|
||||
// We don't have dlsym on Windows, use GetProcAddress instead.
|
||||
#if defined(_WIN32)
|
||||
GetProcAddress(nullptr, registerer_name)
|
||||
#else
|
||||
dlsym(RTLD_DEFAULT, registerer_name)
|
||||
#endif // defined(_WIN32)
|
||||
);
|
||||
SharedLibrary::GetSymbol(registerer_name));
|
||||
|
||||
// Fail in an informative way if the function was not found.
|
||||
if (registerer == nullptr) {
|
||||
// We don't have dlerror on Windows, use GetLastError instead.
|
||||
*error_msg =
|
||||
#if defined(_WIN32)
|
||||
absl::StrFormat("Looking up symbol '%s' failed with error (0x%x).",
|
||||
registerer_name, GetLastError());
|
||||
#else
|
||||
absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
|
||||
registerer_name, dlerror());
|
||||
#endif // defined(_WIN32)
|
||||
registerer_name, SharedLibrary::GetError());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
63
tensorflow/lite/shared_library.h
Normal file
63
tensorflow/lite/shared_library.h
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SHARED_LIBRARY_H_
|
||||
#define TENSORFLOW_LITE_SHARED_LIBRARY_H_
|
||||
|
||||
#if defined(_WIN32)
|
||||
// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead.
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// SharedLibrary provides a uniform set of APIs across different platforms to
|
||||
// handle dynamic library operations
|
||||
class SharedLibrary {
|
||||
public:
|
||||
#if defined(_WIN32)
|
||||
static inline void* LoadLibrary(const char* lib) {
|
||||
return ::LoadLibrary(lib);
|
||||
}
|
||||
static inline void* GetLibrarySymbol(void* handle, const char* symbol) {
|
||||
return static_cast<void*>(
|
||||
GetProcAddress(static_cast<HMODULE>(handle), symbol));
|
||||
}
|
||||
static inline void* GetSymbol(const char* symbol) {
|
||||
return static_cast<void*>(GetProcAddress(nullptr, symbol));
|
||||
}
|
||||
static inline int UnLoadLibrary(void* handle) {
|
||||
return FreeLibrary(static_cast<HMODULE>(handle));
|
||||
}
|
||||
static inline const char* GetError() { return "Unknown"; }
|
||||
#else
|
||||
static inline void* LoadLibrary(const char* lib) {
|
||||
return dlopen(lib, RTLD_LAZY | RTLD_LOCAL);
|
||||
}
|
||||
static inline void* GetLibrarySymbol(void* handle, const char* symbol) {
|
||||
return dlsym(handle, symbol);
|
||||
}
|
||||
static inline void* GetSymbol(const char* symbol) {
|
||||
return dlsym(RTLD_DEFAULT, symbol);
|
||||
}
|
||||
static inline int UnLoadLibrary(void* handle) { return dlclose(handle); }
|
||||
static inline const char* GetError() { return dlerror(); }
|
||||
#endif // defined(_WIN32)
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_SHARED_LIBRARY_H_
|
Loading…
Reference in New Issue
Block a user