Make tf-executor-graph-pruning more conservative to support TFV1 graphs

This is adding a heuristic to ignore function with no arguments/returns as they could
represent TFV1 graphs imported without feeds/fetches.

PiperOrigin-RevId: 295251291
Change-Id: I6b08b4327b85bc0ae6a36e914683984af85faf53
This commit is contained in:
Mehdi Amini 2020-02-14 16:23:25 -08:00 committed by TensorFlower Gardener
parent c282d1a905
commit 33f00e722e
2 changed files with 14 additions and 8 deletions

View File

@ -18,17 +18,15 @@ func @chained_islands(%arg0 : i32) -> i32 {
return %0 : i32
}
// Check that empty islands that don't contribute to the fetch are removed.
// Check that a function that does not have arguments/results is ignored by
// thep pruning pass: this could be a V1 graph imported without feeds/fetches.
// CHECK-LABEL: func @empty_islands(
func @empty_islands() {
// CHECK-NOT: tf_executor.island
// CHECK: tf_executor.island
tf_executor.graph {
%0 = tf_executor.island {
tf_executor.yield
}
%1 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
@ -87,7 +85,7 @@ func @nextiteration_deleted(%arg0 : i32) -> i32 {
// Check that NextIteration.source/sink ops and associated ops are deleted when
// associated loop is unreachable.
// CHECK-LABEL: func @unreachable_loop
func @unreachable_loop() {
func @unreachable_loop(%arg0 : i32) {
// CHECK: tf_executor.graph
// CHECK-NEXT: tf_executor.fetch
tf_executor.graph {
@ -104,7 +102,7 @@ func @unreachable_loop() {
%10:2 = tf_executor.island(%9#1) wraps "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : () -> tensor<i32>
%11:2 = tf_executor.island wraps "tf.Add"(%9#0, %10#0) {T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
tf_executor.NextIteration.Sink [%0#1] %11#0 : tensor<*xi32> {T = "tfdtype$DT_INT32"}
tf_executor.fetch
tf_executor.fetch %arg0 : i32
}
return
}

View File

@ -86,7 +86,15 @@ namespace {
// This transformation pass prunes a TF graph eliminating dead-nodes.
struct GraphPruning : public FunctionPass<GraphPruning> {
void runOnFunction() override {
getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
getFunction().walk([](tf_executor::GraphOp graph) {
// For TensorFlow V1.0 compatibility: when importing a graph without
// providing feeds/fetches we should not attempt to prune. The best
// approximation here is to check if the graph does not have any fetched
// values.
if (!graph.GetFetch().getNumOperands()) return;
PruneGraph(graph);
});
}
};