From d14a44fb494861f339d34b26bee5ddf495123ae8 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Tue, 20 Oct 2020 07:05:40 -0700 Subject: [PATCH] Update importer to always populate "tf.entry_function" attribute when graph_as_function is set and update graph pruning pass to only not prune graphs for functions named "main" and the "tf.entry_function" attribute is not set. Function graphs imported should always be prunable due to explicit feed/fetch/target nodes. If a function graph has no fetch/target nodes, the whole graph should be pruned. This will now only leave special casing of Graphs imported via the v1 manner where feed/fetch/target nodes are unknown. PiperOrigin-RevId: 338053441 Change-Id: Ide746d70512ed525637c7a3f55213eab8393de90 --- .../compile_mlir_util/graph-resource.pbtxt | 2 +- .../mlir/tensorflow/tests/graph_pruning.mlir | 50 +++++++++++++------ .../tensorflow/transforms/graph_pruning.cc | 26 ++++++---- .../mlir/tensorflow/translate/import_model.cc | 41 ++++++++------- 4 files changed, 73 insertions(+), 46 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt index 5fb90b1bce0..b8c779992ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=assign_variable | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 
-tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s node { name: "arg0" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir index c52488b4afc..1f0a183c19e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tf-executor-graph-pruning | FileCheck %s +// RUN: tf-opt %s -split-input-file -tf-executor-graph-pruning | FileCheck %s // Two islands chained by data-flow contributing to the graph return are // preserved. @@ -18,20 +18,6 @@ func @chained_islands(%arg0 : i32) -> i32 { return %0 : i32 } -// Check that a function that does not have arguments/results is ignored by -// thep pruning pass: this could be a V1 graph imported without feeds/fetches. -// CHECK-LABEL: func @empty_islands( -func @empty_islands() { -// CHECK: tf_executor.island - tf_executor.graph { - %0 = tf_executor.island { - tf_executor.yield - } - tf_executor.fetch - } - return -} - // Check that an unused island that doesn't contribute to the fetch is removed. // CHECK-LABEL: func @dead_island( func @dead_island(%arg0 : i32) -> i32 { @@ -165,3 +151,37 @@ func @control_fetch(%arg0 : i32) { } return } + +// ----- + +// Check that a function that is named "main" and does not have the +// "tf.entry_function" attribute defined is ignored by the pruning pass: this +// could be a V1 graph imported without feed/fetch/target nodes. +// CHECK-LABEL: func @main( +func @main() { +// CHECK: tf_executor.island + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// ----- + +// Check that a function that is named "main" and does have the +// "tf.entry_function" attribute defined with no feed/fetch/target nodes is +// pruned. 
+// CHECK-LABEL: func @main( +func @main() attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = ""}} { +// CHECK-NOT: tf_executor.island + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index 859d3ffb23c..26c0126932c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -30,6 +31,18 @@ limitations under the License. namespace mlir { namespace tf_executor { +namespace { + +// Checks if a tf_executor.Graph can be pruned. +// For TensorFlow V1.0 compatibility: when importing a graph without providing +// feeds/fetches/targets we should not attempt to prune. The best approximation +// here is to check if the graph is of the "main" function and does not have the +// "tf.entry_function" attribute defined. +bool CanPruneGraph(FuncOp func) { + return func.getName() != "main" || + func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr; +} + // Visits an op's operand if it is an output of an Operation in the same // tf_executor.graph. void VisitOpOperand(GraphOp graph, Value operand, @@ -75,6 +88,8 @@ void VisitOp(GraphOp graph, Operation* op, } } +} // namespace + // Prunes unreachable operations of a tf_executor.graph operation. 
void PruneGraph(GraphOp graph) { // A graph has a single block which forms a DAG: operations that aren't @@ -107,15 +122,8 @@ namespace { // This transformation pass prunes a TF graph eliminating dead-nodes. struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> { void runOnFunction() override { - getFunction().walk([](tf_executor::GraphOp graph) { - // For TensorFlow V1.0 compatibility: when importing a graph without - // providing feeds/fetches we should not attempt to prune. The best - // approximation here is to check if the graph does not have any fetched - // values. - if (!graph.GetFetch().getNumOperands()) return; - - PruneGraph(graph); - }); + if (!CanPruneGraph(getFunction())) return; + getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 42ce5c533a2..a315a5523ca 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -2126,28 +2126,27 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert( TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, &control_ret_nodes)); - if (!arg_nodes.empty() || !ret_nodes.empty() || - !control_ret_nodes.empty()) { - mlir::Builder b(context); - std::string s; - llvm::raw_string_ostream ss(s); - auto node_name = [&](const OutputTensor& tensor) { - ss << tensor.node->name(); - }; - llvm::interleave(arg_nodes, ss, node_name, ","); - auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(ret_nodes, ss, node_name, ","); - auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(specs.control_outputs, ss, ","); - auto control_outputs = - b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); + mlir::Builder b(context); + std::string s; + llvm::raw_string_ostream ss(s); + auto node_name = 
[&](const OutputTensor& tensor) { + ss << tensor.node->name(); + }; + llvm::interleave(arg_nodes, ss, node_name, ","); + auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(ret_nodes, ss, node_name, ","); + auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(specs.control_outputs, ss, ","); + auto control_outputs = + b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); - attrs.push_back(b.getNamedAttr( - "tf.entry_function", - b.getDictionaryAttr({inputs, outputs, control_outputs}))); - } + // Under `graph_as_function` mode, `tf.entry_function` is always set as it + // is assumed feed, fetch, and target nodes are set correctly. + attrs.push_back(b.getNamedAttr( + "tf.entry_function", + b.getDictionaryAttr({inputs, outputs, control_outputs}))); } else { // Collects the argument and return nodes by looking up the node names // specified by the user.