From 6506145ebb079c7898dbb9861a57721770aeaf99 Mon Sep 17 00:00:00 2001 From: Ashwin Murthy Date: Thu, 5 Mar 2020 11:34:41 -0800 Subject: [PATCH] Remove saved_model_import flag guard on pass config Since saved model passes are a no-op for graphdef import and safe to always be included. PiperOrigin-RevId: 299151617 Change-Id: If302d005b901ff3c5646784ac8f2a0ba9e372490 --- .../mlir/lite/common/tfl_pass_config.h | 7 +---- .../compiler/mlir/lite/tf_tfl_passes.cc | 29 ++++++++----------- .../compiler/mlir/lite/tf_tfl_translate.cc | 2 -- 3 files changed, 13 insertions(+), 25 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 0b05a156442..b14041e8067 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -36,8 +36,7 @@ struct PassConfig { form_clusters(false), inline_functions(true), unfold_batch_matmul(true), - legalize_tf_while(true), - saved_model_import(false) {} + legalize_tf_while(true) {} // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be // added, which produces TF Lite ops. @@ -67,10 +66,6 @@ struct PassConfig { // Note: This is staging step and will be removed. // TODO(b/137395003): Remove post switching legalization. bool legalize_tf_while; - - // This flag indicates whether the TF program to be converted is being - // imported into MLIR via saved model import. - bool saved_model_import; }; } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 358db8fa099..8edeb827528 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -84,17 +84,15 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); } - if (pass_config.saved_model_import) { - // This pass does resource analysis of saved model global tensors and marks - // those deemed read-only as immutable. - pass_manager->addPass( - mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); - // This pass marks non-exported functions as symbol visibility 'private' - // those deemed read-only as immutable. - pass_manager->addPass( - mlir::tf_saved_model:: - CreateMarkFunctionVisibilityUsingSavedModelLinkagePass()); - } + // This pass does resource analysis of saved model global tensors and marks + // those deemed read-only as immutable. + pass_manager->addPass( + mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + // This pass marks non-exported functions as symbol visibility 'private' + // those deemed read-only as immutable. + pass_manager->addPass( + mlir::tf_saved_model:: + CreateMarkFunctionVisibilityUsingSavedModelLinkagePass()); // Enable fusing composite ops that can be lowered to built-in TFLite ops. if (pass_config.emit_builtin_tflite_ops) { @@ -139,12 +137,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addNestedPass(mlir::createCSEPass()); // This pass does dead code elimination based on symbol visibility. pass_manager->addPass(mlir::createSymbolDCEPass()); - if (pass_config.saved_model_import) { - // This pass 'freezes' immutable global tensors and inlines them as tf - // constant ops. - pass_manager->addPass( - mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); - } + // This pass 'freezes' immutable global tensors and inlines them as tf + // constant ops. + pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); // The below passes only make sense if Builtin TFLite ops are enabled // for emission. diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 21712e79065..7f8ce4cf3d4 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -195,8 +195,6 @@ int main(int argc, char **argv) { pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.lower_tensor_list_ops = lower_tensor_list_ops; pass_config.inline_functions = inline_functions; - if (import_saved_model || import_saved_model_v1) - pass_config.saved_model_import = true; tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());