diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir index 8585790564b..771ad5e30d8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir @@ -167,3 +167,16 @@ func @control_fetch(%arg0 : i32) { } return } + +// Check that @main function is pruned. +// CHECK-LABEL: func @main +func @main() { + tf_executor.graph { + // CHECK-NOT: tf_executor.island + %0 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir new file mode 100644 index 00000000000..86568cccd0f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir @@ -0,0 +1,14 @@ +// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail + +// Check that @main function is skipped by default. +// CHECK-LABEL: func @main +func @main() { + tf_executor.graph { + // CHECKT: tf_executor.island + %0 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index 882e769ff4c..23cdebc4323 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -86,17 +86,36 @@ namespace { // This transformation pass prunes a TF graph eliminating dead-nodes. struct GraphPruning : public FunctionPass { void runOnFunction() override { - getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); + FuncOp func = getFunction(); + if (func.getName() == "main" && skip_main_func) return; + func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); } + + struct Options : public PassOptions { + Option skip_main_func{ + *this, "skip-main-func", + llvm::cl::desc("skip graph pruning for main function"), + llvm::cl::init(false)}; + }; + + explicit GraphPruning(bool skip_main_func) + : FunctionPass(), skip_main_func(skip_main_func) {} + + explicit GraphPruning(const Options& option) + : GraphPruning(option.skip_main_func) {} + + private: + bool skip_main_func; }; } // namespace -std::unique_ptr> CreateTFExecutorGraphPruningPass() { - return std::make_unique(); +std::unique_ptr> CreateTFExecutorGraphPruningPass( + bool skip_main_func) { + return std::make_unique(skip_main_func); } -static PassRegistration pass( +static PassRegistration pass( "tf-executor-graph-pruning", "Prune unreachable nodes in a TensorFlow Graph."); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index fca1c02bc62..d8904949eb5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -79,7 +79,8 @@ std::unique_ptr> CreateSwitchFoldPass(); std::unique_ptr> CreateTFExecutorIslandCoarseningPass(); // Create a pass to prune tf_executor.graph from dead nodes. -std::unique_ptr> CreateTFExecutorGraphPruningPass(); +std::unique_ptr> CreateTFExecutorGraphPruningPass( + bool skip_main_func = false); // Prunes unreachable operations of a tf_executor.graph operation. void PruneGraph(GraphOp graph);