Enable flex delegate on tensorflow.lite.Interpreter Python package
Usually, flex delegate is enabled by symbol override of AcquireFlexDelegate() function. But this approach doesn't work well with shared library. Since pywrap_tensorflow_internal.so is available for tensorflow PIP, I've made the following changes to enable flex delegate. - Included flex delegate module to the pywrap_tensorflow_internal.so. This file already contains most TF internal logic and having TFLite flex delegate impacts about 72K to the output. - Added new function of TF_AcquireFlexDelegate() in the delegate module. - Updated logic in AcquireFlexDelegate() of interpreter_builder.cc to check the availability of pywrap_tensorflow_internal.so and lookup the TF_AcquireFlexDelegate() symbol to enable flex delegate. Also updated python/lite_flex_test.py since flex delegate is supported with Python API PiperOrigin-RevId: 317044994 Change-Id: Ic5e953f4a675b3f5360a4c7d607568193103711a
This commit is contained in:
parent
0a8019fa2b
commit
64e1b489bb
|
@ -136,3 +136,10 @@ TfLiteStatus FlexDelegate::CopyFromBufferHandle(
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
// Exported C interface function which is used by AcquireFlexDelegate() at
|
||||||
|
// interpreter_build.cc. To export the function name globally, the function name
|
||||||
|
// must be matched with patterns in tf_version_script.lds
|
||||||
|
extern "C" tflite::TfLiteDelegateUniquePtr TF_AcquireFlexDelegate() {
|
||||||
|
return tflite::AcquireFlexDelegate();
|
||||||
|
}
|
||||||
|
|
|
@ -14,6 +14,9 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/interpreter_builder.h"
|
#include "tensorflow/lite/interpreter_builder.h"
|
||||||
|
|
||||||
|
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
|
||||||
|
#include <dlfcn.h>
|
||||||
|
#endif
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
@ -114,6 +117,20 @@ const char* kEmptyTensorName = "";
|
||||||
// For flex delegate, see also the strong override in
|
// For flex delegate, see also the strong override in
|
||||||
// lite/delegates/flex/delegate.cc.
|
// lite/delegates/flex/delegate.cc.
|
||||||
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
|
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
|
||||||
|
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
|
||||||
|
// If _pywrap_tensorflow_internal.so is available, use
|
||||||
|
// TF_AcquireFlexDelegate() to initialize flex delegate.
|
||||||
|
void* lib_tf_internal =
|
||||||
|
dlopen("_pywrap_tensorflow_internal.so", RTLD_NOW | RTLD_LOCAL);
|
||||||
|
if (lib_tf_internal) {
|
||||||
|
auto TF_AcquireFlexDelegate =
|
||||||
|
reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
|
||||||
|
dlsym(lib_tf_internal, "TF_AcquireFlexDelegate"));
|
||||||
|
if (TF_AcquireFlexDelegate) {
|
||||||
|
return TF_AcquireFlexDelegate();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
|
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -193,9 +193,8 @@ py_test(
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
# TODO(b/111881877): Enable in oss after resolving op registry issues.
|
"no_mac", # TODO(b/159077703): Enable Python API Flex support on MacOS.
|
||||||
"no_oss",
|
"no_windows", # TODO(b/159077703): Enable Python API Flex support on Windows.
|
||||||
"no_windows",
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":lite",
|
":lite",
|
||||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.lite.python import lite
|
from tensorflow.lite.python import lite
|
||||||
from tensorflow.lite.python.interpreter import Interpreter
|
from tensorflow.lite.python.interpreter import Interpreter
|
||||||
|
@ -41,8 +42,7 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
('DisableMlirConverter', False)) # disable mlir
|
('DisableMlirConverter', False)) # disable mlir
|
||||||
def testFlexMode(self, enable_mlir):
|
def testFlexMode(self, enable_mlir):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
in_tensor = array_ops.placeholder(
|
in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
|
||||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
|
||||||
out_tensor = in_tensor + in_tensor
|
out_tensor = in_tensor + in_tensor
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
|
|
||||||
|
@ -54,19 +54,22 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
self.assertTrue(tflite_model)
|
self.assertTrue(tflite_model)
|
||||||
|
|
||||||
# Ensures the model contains TensorFlow ops.
|
# Check the model works with TensorFlow ops.
|
||||||
# TODO(nupurgarg): Check values once there is a Python delegate interface.
|
|
||||||
interpreter = Interpreter(model_content=tflite_model)
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
with self.assertRaises(RuntimeError) as error:
|
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
self.assertIn(
|
input_details = interpreter.get_input_details()
|
||||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
|
||||||
str(error.exception))
|
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||||
|
interpreter.invoke()
|
||||||
|
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
|
||||||
|
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||||
|
self.assertTrue((expected_output == output_data).all())
|
||||||
|
|
||||||
def testDeprecatedFlags(self):
|
def testDeprecatedFlags(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
in_tensor = array_ops.placeholder(
|
in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
|
||||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
|
||||||
out_tensor = in_tensor + in_tensor
|
out_tensor = in_tensor + in_tensor
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
|
|
||||||
|
@ -83,14 +86,18 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
self.assertTrue(tflite_model)
|
self.assertTrue(tflite_model)
|
||||||
|
|
||||||
# Ensures the model contains TensorFlow ops.
|
# Check the model works with TensorFlow ops.
|
||||||
# TODO(nupurgarg): Check values once there is a Python delegate interface.
|
|
||||||
interpreter = Interpreter(model_content=tflite_model)
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
with self.assertRaises(RuntimeError) as error:
|
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
self.assertIn(
|
input_details = interpreter.get_input_details()
|
||||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
|
||||||
str(error.exception))
|
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||||
|
interpreter.invoke()
|
||||||
|
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
|
||||||
|
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||||
|
self.assertTrue((expected_output == output_data).all())
|
||||||
|
|
||||||
|
|
||||||
class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
|
class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
|
||||||
|
@ -114,14 +121,18 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
|
||||||
converter.experimental_new_converter = enable_mlir
|
converter.experimental_new_converter = enable_mlir
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
# Ensures the model contains TensorFlow ops.
|
# Check the model works with TensorFlow ops.
|
||||||
# TODO(nupurgarg): Check values once there is a Python delegate interface.
|
|
||||||
interpreter = Interpreter(model_content=tflite_model)
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
with self.assertRaises(RuntimeError) as error:
|
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
self.assertIn(
|
input_details = interpreter.get_input_details()
|
||||||
'Regular TensorFlow ops are not supported by this interpreter.',
|
test_input = np.array([4.0], dtype=np.float32)
|
||||||
str(error.exception))
|
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||||
|
interpreter.invoke()
|
||||||
|
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
expected_output = np.array([24.0], dtype=np.float32)
|
||||||
|
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||||
|
self.assertTrue((expected_output == output_data).all())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -6058,7 +6058,12 @@ pywrap_tensorflow_macro(
|
||||||
"@ngraph_tf//:ngraph_tf",
|
"@ngraph_tf//:ngraph_tf",
|
||||||
]) + if_xla_available([
|
]) + if_xla_available([
|
||||||
"//tensorflow/compiler/aot:tfcompile_lib",
|
"//tensorflow/compiler/aot:tfcompile_lib",
|
||||||
]),
|
]) + select({
|
||||||
|
"//tensorflow:windows": [], # TODO(b/159077703): Enable Flex on Windows
|
||||||
|
"//conditions:default": [
|
||||||
|
"//tensorflow/lite/delegates/flex:delegate",
|
||||||
|
],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# ** Targets for Windows build (start) **
|
# ** Targets for Windows build (start) **
|
||||||
|
|
Loading…
Reference in New Issue