Enable flex delegate on tensorflow.lite.Interpreter Python package

Usually, flex delegate is enabled by symbol override of AcquireFlexDelegate()
function. But this approach doesn't work well with shared library.

Since pywrap_tensorflow_internal.so is available for tensorflow PIP,
I've made the following changes to enable flex delegate.
- Included flex delegate module to the pywrap_tensorflow_internal.so.
  This file already contains most TF internal logic and having TFLite flex
  delegate impacts about 72K to the output.
- Added new function of TF_AcquireFlexDelegate() in the delegate module.
- Updated logic in AcquireFlexDelegate() of interpreter_builder.cc to check
  the availability of pywrap_tensorflow_internal.so and lookup the
  TF_AcquireFlexDelegate() symbol to enable flex delegate.

Also updated python/lite_flex_test.py since flex delegate is supported with
Python API

PiperOrigin-RevId: 317044994
Change-Id: Ic5e953f4a675b3f5360a4c7d607568193103711a
This commit is contained in:
Terry Heo 2020-06-17 23:54:57 -07:00 committed by TensorFlower Gardener
parent 0a8019fa2b
commit 64e1b489bb
5 changed files with 68 additions and 29 deletions

View File

@ -136,3 +136,10 @@ TfLiteStatus FlexDelegate::CopyFromBufferHandle(
} }
} // namespace tflite } // namespace tflite
// Exported C interface function which is used by AcquireFlexDelegate() at
// interpreter_build.cc. To export the function name globally, the function name
// must be matched with patterns in tf_version_script.lds
extern "C" tflite::TfLiteDelegateUniquePtr TF_AcquireFlexDelegate() {
return tflite::AcquireFlexDelegate();
}

View File

@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/interpreter_builder.h"
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
#include <dlfcn.h>
#endif
#include <fcntl.h> #include <fcntl.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
@ -114,6 +117,20 @@ const char* kEmptyTensorName = "";
// For flex delegate, see also the strong override in // For flex delegate, see also the strong override in
// lite/delegates/flex/delegate.cc. // lite/delegates/flex/delegate.cc.
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32)
// If _pywrap_tensorflow_internal.so is available, use
// TF_AcquireFlexDelegate() to initialize flex delegate.
void* lib_tf_internal =
dlopen("_pywrap_tensorflow_internal.so", RTLD_NOW | RTLD_LOCAL);
if (lib_tf_internal) {
auto TF_AcquireFlexDelegate =
reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
dlsym(lib_tf_internal, "TF_AcquireFlexDelegate"));
if (TF_AcquireFlexDelegate) {
return TF_AcquireFlexDelegate();
}
}
#endif
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
} }

View File

@ -193,9 +193,8 @@ py_test(
python_version = "PY3", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [
# TODO(b/111881877): Enable in oss after resolving op registry issues. "no_mac", # TODO(b/159077703): Enable Python API Flex support on MacOS.
"no_oss", "no_windows", # TODO(b/159077703): Enable Python API Flex support on Windows.
"no_windows",
], ],
deps = [ deps = [
":lite", ":lite",

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np
from tensorflow.lite.python import lite from tensorflow.lite.python import lite
from tensorflow.lite.python.interpreter import Interpreter from tensorflow.lite.python.interpreter import Interpreter
@ -41,8 +42,7 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
('DisableMlirConverter', False)) # disable mlir ('DisableMlirConverter', False)) # disable mlir
def testFlexMode(self, enable_mlir): def testFlexMode(self, enable_mlir):
with ops.Graph().as_default(): with ops.Graph().as_default():
in_tensor = array_ops.placeholder( in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor out_tensor = in_tensor + in_tensor
sess = session.Session() sess = session.Session()
@ -54,19 +54,22 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
tflite_model = converter.convert() tflite_model = converter.convert()
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
# Ensures the model contains TensorFlow ops. # Check the model works with TensorFlow ops.
# TODO(nupurgarg): Check values once there is a Python delegate interface.
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
with self.assertRaises(RuntimeError) as error:
interpreter.allocate_tensors() interpreter.allocate_tensors()
self.assertIn( input_details = interpreter.get_input_details()
'Regular TensorFlow ops are not supported by this interpreter.', test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
str(error.exception)) interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_details = interpreter.get_output_details()
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
def testDeprecatedFlags(self): def testDeprecatedFlags(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
in_tensor = array_ops.placeholder( in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor out_tensor = in_tensor + in_tensor
sess = session.Session() sess = session.Session()
@ -83,14 +86,18 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
tflite_model = converter.convert() tflite_model = converter.convert()
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
# Ensures the model contains TensorFlow ops. # Check the model works with TensorFlow ops.
# TODO(nupurgarg): Check values once there is a Python delegate interface.
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
with self.assertRaises(RuntimeError) as error:
interpreter.allocate_tensors() interpreter.allocate_tensors()
self.assertIn( input_details = interpreter.get_input_details()
'Regular TensorFlow ops are not supported by this interpreter.', test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
str(error.exception)) interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_details = interpreter.get_output_details()
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
class FromConcreteFunctionTest(test_util.TensorFlowTestCase, class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
@ -114,14 +121,18 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
converter.experimental_new_converter = enable_mlir converter.experimental_new_converter = enable_mlir
tflite_model = converter.convert() tflite_model = converter.convert()
# Ensures the model contains TensorFlow ops. # Check the model works with TensorFlow ops.
# TODO(nupurgarg): Check values once there is a Python delegate interface.
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
with self.assertRaises(RuntimeError) as error:
interpreter.allocate_tensors() interpreter.allocate_tensors()
self.assertIn( input_details = interpreter.get_input_details()
'Regular TensorFlow ops are not supported by this interpreter.', test_input = np.array([4.0], dtype=np.float32)
str(error.exception)) interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_details = interpreter.get_output_details()
expected_output = np.array([24.0], dtype=np.float32)
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -6058,7 +6058,12 @@ pywrap_tensorflow_macro(
"@ngraph_tf//:ngraph_tf", "@ngraph_tf//:ngraph_tf",
]) + if_xla_available([ ]) + if_xla_available([
"//tensorflow/compiler/aot:tfcompile_lib", "//tensorflow/compiler/aot:tfcompile_lib",
]), ]) + select({
"//tensorflow:windows": [], # TODO(b/159077703): Enable Flex on Windows
"//conditions:default": [
"//tensorflow/lite/delegates/flex:delegate",
],
}),
) )
# ** Targets for Windows build (start) ** # ** Targets for Windows build (start) **