diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir index 8585790564b..042cdaf5820 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir @@ -18,17 +18,15 @@ func @chained_islands(%arg0 : i32) -> i32 { return %0 : i32 } -// Check that empty islands that don't contribute to the fetch are removed. +// Check that a function that does not have arguments/results is ignored by +// thep pruning pass: this could be a V1 graph imported without feeds/fetches. // CHECK-LABEL: func @empty_islands( func @empty_islands() { -// CHECK-NOT: tf_executor.island +// CHECK: tf_executor.island tf_executor.graph { %0 = tf_executor.island { tf_executor.yield } - %1 = tf_executor.island { - tf_executor.yield - } tf_executor.fetch } return @@ -87,7 +85,7 @@ func @nextiteration_deleted(%arg0 : i32) -> i32 { // Check that NextIteration.source/sink ops and associated ops are deleted when // associated loop is unreachable. // CHECK-LABEL: func @unreachable_loop -func @unreachable_loop() { +func @unreachable_loop(%arg0 : i32) { // CHECK: tf_executor.graph // CHECK-NEXT: tf_executor.fetch tf_executor.graph { @@ -104,7 +102,7 @@ func @unreachable_loop() { %10:2 = tf_executor.island(%9#1) wraps "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<1> : tensor} : () -> tensor %11:2 = tf_executor.island wraps "tf.Add"(%9#0, %10#0) {T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor) -> tensor<*xi32> tf_executor.NextIteration.Sink [%0#1] %11#0 : tensor<*xi32> {T = "tfdtype$DT_INT32"} - tf_executor.fetch + tf_executor.fetch %arg0 : i32 } return } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index c7dac93101b..d52c49e4436 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -86,7 +86,15 @@ namespace { // This transformation pass prunes a TF graph eliminating dead-nodes. struct GraphPruning : public FunctionPass { void runOnFunction() override { - getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); + getFunction().walk([](tf_executor::GraphOp graph) { + // For TensorFlow V1.0 compatibility: when importing a graph without + // providing feeds/fetches we should not attempt to prune. The best + // approximation here is to check if the graph does not have any fetched + // values. + if (!graph.GetFetch().getNumOperands()) return; + + PruneGraph(graph); + }); } };