diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 62b226cf852..99740515a48 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -521,6 +521,7 @@ cc_library( ":tensorflow_lite_quantize", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 25d15614ef6..8e8d102994b 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -41,6 +41,7 @@ bool ShouldRunQuantizePasses(mlir::ModuleOp m) { void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, mlir::PassManager* pass_manager) { + pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass()); pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion()); pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); // Ophint extraction will happen after island extraction pass. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index 92087de44af..5e29a00d7c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -58,7 +58,7 @@ limitations under the License. namespace mlir { namespace { -class SwitchFold : public mlir::FunctionPass { +class SwitchFoldPass : public mlir::FunctionPass { public: void runOnFunction() override; }; @@ -266,7 +266,7 @@ bool HasSendOrReceive(FuncOp function) { .wasInterrupted(); } -void SwitchFold::runOnFunction() { +void SwitchFoldPass::runOnFunction() { if (HasSendOrReceive(getFunction())) return; DeadQueue queue; // Initialize dead queue with dead outputs of foldable SwitchOps. @@ -277,7 +277,13 @@ void SwitchFold::runOnFunction() { if (failed(FoldMergeNodes(getFunction(), queue))) return signalPassFailure(); } // namespace mlir -static PassRegistration pass( +namespace tf_executor { +std::unique_ptr CreateSwitchFoldPass() { + return std::make_unique(); +} +} // namespace tf_executor + +static PassRegistration pass( "tf-switch-fold", "Fold switch nodes with constant predicates"); } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 0306abd55cb..6489677043f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -41,6 +41,9 @@ std::unique_ptr CreateRaiseTFControlFlowPass(); namespace tf_executor { class GraphOp; +// Returns a pass that folds switch nodes with constant predicates. +std::unique_ptr CreateSwitchFoldPass(); + // Create a pass to merge IslandOps from TFExecutor dialect. std::unique_ptr CreateTFExecutorIslandCoarseningPass();