diff --git a/tensorflow/lite/python/lite_mlir_test.py b/tensorflow/lite/python/lite_mlir_test.py index 98c0a5fe36e..f234eaf2301 100644 --- a/tensorflow/lite/python/lite_mlir_test.py +++ b/tensorflow/lite/python/lite_mlir_test.py @@ -40,28 +40,6 @@ from tensorflow.python.platform import test from tensorflow.python.training.tracking import tracking -def mlir_convert_and_check_for_unsupported(test_object, converter): - """Run the converter but don't fail MLIR was not built. - - Args: - test_object: PyTest object. - converter: A TFLiteConverter - - Returns: - The converted TF lite model or None if mlir support is not builtinto the - binary. - """ - try: - model = converter.convert() - test_object.assertTrue(model) - return model - except lite.ConverterError as e: - if not e.message.startswith('This flag is not supported by this version'): - raise e - else: - return None - - @test_util.run_v1_only('Incompatible with 2.0.') class FromSessionTest(test_util.TensorFlowTestCase): @@ -75,9 +53,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) converter.experimental_enable_mlir_converter = True - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. interpreter = Interpreter(model_content=tflite_model) @@ -105,9 +81,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): # Convert model and ensure model is not None. converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. interpreter = Interpreter(model_content=tflite_model) @@ -144,9 +118,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): 'inputA': (0., 1.), 'inputB': (0., 1.) } # mean, std_dev - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. interpreter = Interpreter(model_content=tflite_model) @@ -182,9 +154,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): # Test conversion with the scalar input shape. converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. interpreter = Interpreter(model_content=tflite_model) @@ -228,18 +198,13 @@ class FromSessionTest(test_util.TensorFlowTestCase): # Convert float model. float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1], [out_tensor]) - float_tflite = mlir_convert_and_check_for_unsupported(self, float_converter) - if float_tflite is None: - return + float_tflite = float_converter.convert() # Convert quantized weights model. quantized_converter = lite.TFLiteConverter.from_session( sess, [in_tensor_1], [out_tensor]) quantized_converter.optimizations = [lite.Optimize.DEFAULT] - quantized_tflite = mlir_convert_and_check_for_unsupported( - self, quantized_converter) - if quantized_tflite is None: - return + quantized_tflite = quantized_converter.convert() # Ensure that the quantized weights tflite model is smaller. self.assertLess(len(quantized_tflite), len(float_tflite)) @@ -266,9 +231,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): # Convert model and ensure model is not None. converter = lite.TFLiteConverter.from_session(sess, [placeholder], [output_node]) - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. interpreter = Interpreter(model_content=tflite_model) @@ -322,9 +285,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase): # Convert model. converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) converter.experimental_enable_mlir_converter = True - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. expected_value = root.f(input_data) @@ -359,9 +320,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase): # Convert model. converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) converter.experimental_enable_mlir_converter = True - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. expected_value = concrete_func(**input_data) @@ -389,9 +348,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase): # Convert model. converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) converter.experimental_enable_mlir_converter = True - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. expected_value = concrete_func(input_data)[0] @@ -422,9 +379,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase): # Convert model. converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) converter.experimental_enable_mlir_converter = True - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. expected_value = concrete_func(input_data) @@ -449,9 +404,7 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase): # Convert model. converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) converter.experimental_enable_mlir_converter = True - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Check values from converted model. expected_value = concrete_func(input_data) @@ -478,9 +431,7 @@ class TestFlexMode(test_util.TensorFlowTestCase): [out_tensor]) converter.experimental_enable_mlir_converter = True converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Ensures the model contains TensorFlow ops. # TODO(nupurgarg): Check values once there is a Python delegate interface. @@ -505,10 +456,7 @@ class TestFlexMode(test_util.TensorFlowTestCase): converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) converter.experimental_enable_mlir_converter = True converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) - - tflite_model = mlir_convert_and_check_for_unsupported(self, converter) - if tflite_model is None: - return + tflite_model = converter.convert() # Ensures the model contains TensorFlow ops. # TODO(nupurgarg): Check values once there is a Python delegate interface. diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index 1f4e86f85c8..79357f66676 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") -load("//tensorflow:tensorflow.bzl", "if_mlir", "py_binary", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "if_mlir_tflite", "py_binary", "tf_py_test") package( default_visibility = [ @@ -22,7 +22,7 @@ cc_library( name = "toco_python_api", srcs = ["toco_python_api.cc"], hdrs = ["toco_python_api.h"], - defines = if_mlir( + defines = if_mlir_tflite( if_false = [], if_true = ["TFLITE_BUILD_WITH_MLIR_CONVERTER"], ), @@ -46,7 +46,7 @@ cc_library( "//tensorflow/core:ops", ], "//conditions:default": [], - }) + if_mlir( + }) + if_mlir_tflite( if_false = [], if_true = ["//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer"], ), diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 5d9aba8637a..d253d5b8799 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2493,5 +2493,8 @@ def if_mlir(if_true, if_false = []): "//tensorflow:with_mlir_support": if_true, }) +def if_mlir_tflite(if_true, if_false = []): + return if_mlir(if_true, if_false) + def tfcompile_extra_flags(): return ""