diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index b14041e8067..0b05a156442 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -36,7 +36,8 @@ struct PassConfig { form_clusters(false), inline_functions(true), unfold_batch_matmul(true), - legalize_tf_while(true) {} + legalize_tf_while(true), + saved_model_import(false) {} // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be // added, which produces TF Lite ops. @@ -66,6 +67,10 @@ struct PassConfig { // Note: This is staging step and will be removed. // TODO(b/137395003): Remove post switching legalization. bool legalize_tf_while; + + // This flag indicates whether the TF program to be converted is being + // imported into MLIR via saved model import. + bool saved_model_import; }; } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index b000de17020..358db8fa099 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -84,6 +84,18 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); } + if (pass_config.saved_model_import) { + // This pass does resource analysis of saved model global tensors and marks + // those deemed read-only as immutable. + pass_manager->addPass( + mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + // This pass marks non-exported functions as symbol visibility 'private' + // those deemed read-only as immutable. + pass_manager->addPass( + mlir::tf_saved_model:: + CreateMarkFunctionVisibilityUsingSavedModelLinkagePass()); + } + // Enable fusing composite ops that can be lowered to built-in TFLite ops. if (pass_config.emit_builtin_tflite_ops) { pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); @@ -114,6 +126,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, mlir::TFL::CreateLegalizeTFWhilePass()); } + if (pass_config.inline_functions) { + pass_manager->addPass(mlir::createInlinerPass()); + } + // TODO(jpienaar): Revise post dialect constants. pass_manager->addPass(mlir::TF::CreateDecodeConstantPass()); // Canonicalization includes const folding, which is utilized here to optimize @@ -121,9 +137,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, // tf.Conv2D is split into tf.Transpose and tfl.Conv2D. pass_manager->addNestedPass(mlir::createCanonicalizerPass()); pass_manager->addNestedPass(mlir::createCSEPass()); - - if (pass_config.inline_functions) { - pass_manager->addPass(mlir::createInlinerPass()); + // This pass does dead code elimination based on symbol visibility. + pass_manager->addPass(mlir::createSymbolDCEPass()); + if (pass_config.saved_model_import) { + // This pass 'freezes' immutable global tensors and inlines them as tf + // constant ops. + pass_manager->addPass( + mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); } // The below passes only make sense if Builtin TFLite ops are enabled