diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir index 771ad5e30d8..8585790564b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir @@ -167,16 +167,3 @@ func @control_fetch(%arg0 : i32) { } return } - -// Check that @main function is pruned. -// CHECK-LABEL: func @main -func @main() { - tf_executor.graph { - // CHECK-NOT: tf_executor.island - %0 = tf_executor.island { - tf_executor.yield - } - tf_executor.fetch - } - return -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir deleted file mode 100644 index 86568cccd0f..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning_skip_main.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail - -// Check that @main function is skipped by default. -// CHECK-LABEL: func @main -func @main() { - tf_executor.graph { - // CHECKT: tf_executor.island - %0 = tf_executor.island { - tf_executor.yield - } - tf_executor.fetch - } - return -} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index 23cdebc4323..882e769ff4c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -86,36 +86,17 @@ namespace { // This transformation pass prunes a TF graph eliminating dead-nodes. struct GraphPruning : public FunctionPass { void runOnFunction() override { - FuncOp func = getFunction(); - if (func.getName() == "main" && skip_main_func) return; - func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); + getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); } - - struct Options : public PassOptions { - Option skip_main_func{ - *this, "skip-main-func", - llvm::cl::desc("skip graph pruning for main function"), - llvm::cl::init(false)}; - }; - - explicit GraphPruning(bool skip_main_func) - : FunctionPass(), skip_main_func(skip_main_func) {} - - explicit GraphPruning(const Options& option) - : GraphPruning(option.skip_main_func) {} - - private: - bool skip_main_func; }; } // namespace -std::unique_ptr> CreateTFExecutorGraphPruningPass( - bool skip_main_func) { - return std::make_unique(skip_main_func); +std::unique_ptr> CreateTFExecutorGraphPruningPass() { + return std::make_unique(); } -static PassRegistration pass( +static PassRegistration pass( "tf-executor-graph-pruning", "Prune unreachable nodes in a TensorFlow Graph."); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index c9c97735848..f870ca298d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -79,8 +79,7 @@ std::unique_ptr> CreateSwitchFoldPass(); std::unique_ptr> CreateTFExecutorIslandCoarseningPass(); // Create a pass to prune tf_executor.graph from dead nodes. -std::unique_ptr> CreateTFExecutorGraphPruningPass( - bool skip_main_func = false); +std::unique_ptr> CreateTFExecutorGraphPruningPass(); // Prunes unreachable operations of a tf_executor.graph operation. void PruneGraph(GraphOp graph); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 4e914a5a20d..b4dfb91e6b5 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -216,8 +216,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, // and canonicalization opportunities that are necessary for the second // LegalizeTFPass(allow_partial_conversion=false) invocation. tf2xla.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(true)); - tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass( - /*skip_main_func=*/true)); + tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); tf2xla.addNestedPass( mlir::xla_hlo::createLegalizeTFPass(false));