diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 97d3f2a1ec6..c997496bdea 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -302,7 +302,9 @@ class TFLiteConverterBase(object): """ if not optimizers: optimizers = [] - optimizers.append("constfold") + # The MLIR converter will take care of constant folding instead of grappler. + if not self.experimental_new_converter: + optimizers.append("constfold") is_only_flex_enabled = ( set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops)) @@ -595,12 +597,17 @@ class TFLiteConverterV2(TFLiteConverterBase): output_tensors = frozen_func.outputs # Run a Grappler pass. - graph_def = _run_graph_optimizations( - graph_def, - input_tensors, - output_tensors, - config=self._grappler_config(), - graph=frozen_func.graph) + grappler_config = self._grappler_config() + # Skip running grappler when there are no optimizers to run. Otherwise, + # grappler will run with the default optimizer set, which can lead to + # unexpected behavior. + if grappler_config.graph_options.rewrite_options.optimizers: + graph_def = _run_graph_optimizations( + graph_def, + input_tensors, + output_tensors, + config=grappler_config, + graph=frozen_func.graph) quant_mode = QuantizationMode(self.optimizations, self.target_spec, self.representative_dataset, graph_def) diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index e0ae7aa4943..6ec70cf0f20 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -2236,6 +2236,37 @@ class GrapplerTest(TestModels, parameterized.TestCase): self.assertEqual('Placeholder', input_details[0]['name']) self.assertEqual('Const', input_details[1]['name']) + def testGrapplerConstFolding(self): + # Constant folding converts the following add operation to a tf.broadcast_to + # operation, which was not supported by TFLite at the time this test was + # added. 
+ @def_function.function + def plus_placeholder(x, placeholder): + return x + placeholder + + with ops.Graph().as_default(): + in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32) + out_tensor = plus_placeholder( + array_ops.zeros([2, 2, 2]), + array_ops.reshape(in_tensor, shape=[2, 2])) + sess = session.Session() + + # Convert model. + converter = lite.TFLiteConverter.from_session(sess, [in_tensor], + [out_tensor]) + # Only disable this path in MLIR conversion for toco compatibility. + converter.experimental_new_converter = True + tflite_model = converter.convert() + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertLen(input_details, 1) + self.assertEqual('Placeholder', input_details[0]['name']) + + class ImportOpsUtilTest(LiteTest): def testGetPotentiallySupportedOps(self):