diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 5eefa821c6b..6ab16141626 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -73,16 +73,17 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); } + if (pass_config.shape_inference) { + pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + } + // Keep this pass after the shape inference pass, which cannot infer + // shapes for non-TF ops. if (!pass_config.quant_specs.serialized_quant_stats.empty()) { pass_manager->addPass( mlir::quant::CreateImportQuantStatsPassForTFControlDialect( pass_config.quant_specs.serialized_quant_stats)); } - if (pass_config.shape_inference) { - pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); - } - // The conversion pipeline has to follow the following orders: // 1) Saved model related optimization like decompose resource ops // 2) Convert composite functions like lstm/rnns, along with proper function