diff --git a/tensorflow/lite/delegates/flex/delegate.cc b/tensorflow/lite/delegates/flex/delegate.cc index 4741bddc2f5..b8b0d4e6d01 100644 --- a/tensorflow/lite/delegates/flex/delegate.cc +++ b/tensorflow/lite/delegates/flex/delegate.cc @@ -136,3 +136,10 @@ TfLiteStatus FlexDelegate::CopyFromBufferHandle( } } // namespace tflite + +// Exported C interface function which is used by AcquireFlexDelegate() at +// interpreter_build.cc. To export the function name globally, the function name +// must be matched with patterns in tf_version_script.lds +extern "C" tflite::TfLiteDelegateUniquePtr TF_AcquireFlexDelegate() { + return tflite::AcquireFlexDelegate(); +} diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index 43d81ef0770..d73b298e595 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/interpreter_builder.h" +#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32) +#include +#endif #include #include #include @@ -114,6 +117,20 @@ const char* kEmptyTensorName = ""; // For flex delegate, see also the strong override in // lite/delegates/flex/delegate.cc. TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { +#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_WIN32) + // If _pywrap_tensorflow_internal.so is available, use + // TF_AcquireFlexDelegate() to initialize flex delegate. + void* lib_tf_internal = + dlopen("_pywrap_tensorflow_internal.so", RTLD_NOW | RTLD_LOCAL); + if (lib_tf_internal) { + auto TF_AcquireFlexDelegate = + reinterpret_cast( + dlsym(lib_tf_internal, "TF_AcquireFlexDelegate")); + if (TF_AcquireFlexDelegate) { + return TF_AcquireFlexDelegate(); + } + } +#endif return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); } diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index d25e7d5ef8d..1b64b7d1042 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -193,9 +193,8 @@ py_test( python_version = "PY3", srcs_version = "PY2AND3", tags = [ - # TODO(b/111881877): Enable in oss after resolving op registry issues. - "no_oss", - "no_windows", + "no_mac", # TODO(b/159077703): Enable Python API Flex support on MacOS. + "no_windows", # TODO(b/159077703): Enable Python API Flex support on Windows. ], deps = [ ":lite", diff --git a/tensorflow/lite/python/lite_flex_test.py b/tensorflow/lite/python/lite_flex_test.py index 26bee206d27..ffc157c2128 100644 --- a/tensorflow/lite/python/lite_flex_test.py +++ b/tensorflow/lite/python/lite_flex_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized +import numpy as np from tensorflow.lite.python import lite from tensorflow.lite.python.interpreter import Interpreter @@ -41,8 +42,7 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase): ('DisableMlirConverter', False)) # disable mlir def testFlexMode(self, enable_mlir): with ops.Graph().as_default(): - in_tensor = array_ops.placeholder( - shape=[1, 16, 16, 3], dtype=dtypes.float32) + in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32) out_tensor = in_tensor + in_tensor sess = session.Session() @@ -54,19 +54,22 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase): tflite_model = converter.convert() self.assertTrue(tflite_model) - # Ensures the model contains TensorFlow ops. - # TODO(nupurgarg): Check values once there is a Python delegate interface. + # Check the model works with TensorFlow ops. interpreter = Interpreter(model_content=tflite_model) - with self.assertRaises(RuntimeError) as error: - interpreter.allocate_tensors() - self.assertIn( - 'Regular TensorFlow ops are not supported by this interpreter.', - str(error.exception)) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], test_input) + interpreter.invoke() + + output_details = interpreter.get_output_details() + expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32) + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue((expected_output == output_data).all()) def testDeprecatedFlags(self): with ops.Graph().as_default(): - in_tensor = array_ops.placeholder( - shape=[1, 16, 16, 3], dtype=dtypes.float32) + in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32) out_tensor = in_tensor + in_tensor sess = session.Session() @@ -83,14 +86,18 @@ class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase): tflite_model = converter.convert() self.assertTrue(tflite_model) - # Ensures the model contains TensorFlow ops. - # TODO(nupurgarg): Check values once there is a Python delegate interface. + # Check the model works with TensorFlow ops. interpreter = Interpreter(model_content=tflite_model) - with self.assertRaises(RuntimeError) as error: - interpreter.allocate_tensors() - self.assertIn( - 'Regular TensorFlow ops are not supported by this interpreter.', - str(error.exception)) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], test_input) + interpreter.invoke() + + output_details = interpreter.get_output_details() + expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32) + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue((expected_output == output_data).all()) class FromConcreteFunctionTest(test_util.TensorFlowTestCase, @@ -114,14 +121,18 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase, converter.experimental_new_converter = enable_mlir tflite_model = converter.convert() - # Ensures the model contains TensorFlow ops. - # TODO(nupurgarg): Check values once there is a Python delegate interface. + # Check the model works with TensorFlow ops. interpreter = Interpreter(model_content=tflite_model) - with self.assertRaises(RuntimeError) as error: - interpreter.allocate_tensors() - self.assertIn( - 'Regular TensorFlow ops are not supported by this interpreter.', - str(error.exception)) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + test_input = np.array([4.0], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], test_input) + interpreter.invoke() + + output_details = interpreter.get_output_details() + expected_output = np.array([24.0], dtype=np.float32) + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue((expected_output == output_data).all()) if __name__ == '__main__': diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index d141b719aef..f53859b2915 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -6058,7 +6058,12 @@ pywrap_tensorflow_macro( "@ngraph_tf//:ngraph_tf", ]) + if_xla_available([ "//tensorflow/compiler/aot:tfcompile_lib", - ]), + ]) + select({ + "//tensorflow:windows": [], # TODO(b/159077703): Enable Flex on Windows + "//conditions:default": [ + "//tensorflow/lite/delegates/flex:delegate", + ], + }), ) # ** Targets for Windows build (start) **