Enable flex delegate on tensorflow.lite.Interpreter Python package
Usually, the flex delegate is enabled by a symbol override of the AcquireFlexDelegate() function, but this approach doesn't work well with shared libraries. Since pywrap_tensorflow_internal.so is available in the TensorFlow PIP package, I've made the following changes to enable the flex delegate: - Included the flex delegate module in pywrap_tensorflow_internal.so. This file already contains most of the TF internal logic, and including the TFLite flex delegate adds about 72K to the output. - Added a new function, TF_AcquireFlexDelegate(), to the delegate module. - Updated the logic in AcquireFlexDelegate() in interpreter_builder.cc to check for the availability of pywrap_tensorflow_internal.so and look up the TF_AcquireFlexDelegate() symbol to enable the flex delegate. Also updated python/lite_flex_test.py, since the flex delegate is now supported by the Python API. PiperOrigin-RevId: 317044994 Change-Id: Ic5e953f4a675b3f5360a4c7d607568193103711a
This commit is contained in:
parent
0a8019fa2b
commit
64e1b489bb
|
@ -136,3 +136,10 @@ TfLiteStatus FlexDelegate::CopyFromBufferHandle(
|
|||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
// Exported C interface function which is used by AcquireFlexDelegate() at
|
||||
// interpreter_builder.cc. To export the function name globally, the function name
|
||||
// must match the patterns in tf_version_script.lds.
|
||||
extern "C" tflite::TfLiteDelegateUniquePtr TF_AcquireFlexDelegate() {
|
||||
// Thin forwarder: the extern "C" linkage gives this entry point an unmangled
// name; the delegate itself comes from tflite::AcquireFlexDelegate().
return tflite::AcquireFlexDelegate();
|
||||
}
|
||||
|
|
|
@ -14,6 +14,9 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
|
||||
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
#include <fcntl.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
@ -114,6 +117,20 @@ const char* kEmptyTensorName = "";
|
|||
// For flex delegate, see also the strong override in
|
||||
// lite/delegates/flex/delegate.cc.
|
||||
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
|
||||
// The dynamic lookup below is compiled out on Android, Apple and Windows
// builds; on those platforms this function always returns the null delegate.
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
|
||||
// If _pywrap_tensorflow_internal.so is available, use
|
||||
// TF_AcquireFlexDelegate() to initialize flex delegate.
|
||||
// RTLD_NOW resolves all undefined symbols at load time; RTLD_LOCAL keeps the
// library's symbols out of the global namespace, so the entry point has to be
// looked up explicitly through dlsym().
// NOTE(review): the handle is never passed to dlclose() on any path —
// presumably intentional when a delegate is returned (its code lives in that
// library); confirm for the symbol-not-found path.
void* lib_tf_internal =
|
||||
dlopen("_pywrap_tensorflow_internal.so", RTLD_NOW | RTLD_LOCAL);
|
||||
if (lib_tf_internal) {
|
||||
// dlsym() yields an untyped address; cast it to a zero-argument function
// returning Interpreter::TfLiteDelegatePtr.
// NOTE(review): assumes the exported function's actual return type is
// identical to Interpreter::TfLiteDelegatePtr — confirm against its
// definition.
auto TF_AcquireFlexDelegate =
|
||||
reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
|
||||
dlsym(lib_tf_internal, "TF_AcquireFlexDelegate"));
|
||||
// A null result means the symbol was not found; fall through to the
// null-delegate return below.
if (TF_AcquireFlexDelegate) {
|
||||
return TF_AcquireFlexDelegate();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// Fallback: no flex delegate could be obtained. Return a null delegate paired
// with a no-op deleter.
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
|
||||
}
|
||||
|
||||
|
|
|
@ -193,9 +193,8 @@ py_test(
|
|||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
# TODO(b/111881877): Enable in oss after resolving op registry issues.
|
||||
"no_oss",
|
||||
"no_windows",
|
||||
"no_mac", # TODO(b/159077703): Enable Python API Flex support on MacOS.
|
||||
"no_windows", # TODO(b/159077703): Enable Python API Flex support on Windows.
|
||||
],
|
||||
deps = [
|
||||
":lite",
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python.interpreter import Interpreter
|
||||
|
@ -41,8 +42,7 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
('DisableMlirConverter', False)) # disable mlir
|
||||
def testFlexMode(self, enable_mlir):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
||||
in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
|
||||
out_tensor = in_tensor + in_tensor
|
||||
sess = session.Session()
|
||||
|
||||
|
@ -54,19 +54,22 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Ensures the model contains TensorFlow ops.
|
||||
# TODO(nupurgarg): Check values once there is a Python delegate interface.
|
||||
# Check the model works with TensorFlow ops.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.allocate_tensors()
|
||||
self.assertIn(
|
||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
||||
str(error.exception))
|
||||
input_details = interpreter.get_input_details()
|
||||
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
|
||||
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||
interpreter.invoke()
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
def testDeprecatedFlags(self):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
||||
in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
|
||||
out_tensor = in_tensor + in_tensor
|
||||
sess = session.Session()
|
||||
|
||||
|
@ -83,14 +86,18 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Ensures the model contains TensorFlow ops.
|
||||
# TODO(nupurgarg): Check values once there is a Python delegate interface.
|
||||
# Check the model works with TensorFlow ops.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.allocate_tensors()
|
||||
self.assertIn(
|
||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
||||
str(error.exception))
|
||||
input_details = interpreter.get_input_details()
|
||||
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
|
||||
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||
interpreter.invoke()
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
|
||||
class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
|
||||
|
@ -114,14 +121,18 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
|
|||
converter.experimental_new_converter = enable_mlir
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Ensures the model contains TensorFlow ops.
|
||||
# TODO(nupurgarg): Check values once there is a Python delegate interface.
|
||||
# Check the model works with TensorFlow ops.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
interpreter.allocate_tensors()
|
||||
self.assertIn(
|
||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
||||
str(error.exception))
|
||||
input_details = interpreter.get_input_details()
|
||||
test_input = np.array([4.0], dtype=np.float32)
|
||||
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||
interpreter.invoke()
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
expected_output = np.array([24.0], dtype=np.float32)
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -6058,7 +6058,12 @@ pywrap_tensorflow_macro(
|
|||
"@ngraph_tf//:ngraph_tf",
|
||||
]) + if_xla_available([
|
||||
"//tensorflow/compiler/aot:tfcompile_lib",
|
||||
]),
|
||||
]) + select({
|
||||
"//tensorflow:windows": [], # TODO(b/159077703): Enable Flex on Windows
|
||||
"//conditions:default": [
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
# ** Targets for Windows build (start) **
|
||||
|
|
Loading…
Reference in New Issue