diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt index b8c779992ac..5fb90b1bce0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=assign_variable | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s node { name: "arg0" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir index 1f0a183c19e..c52488b4afc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -tf-executor-graph-pruning | FileCheck %s +// RUN: tf-opt %s -tf-executor-graph-pruning | FileCheck %s // Two islands chained by data-flow contributing to the graph return are // preserved. @@ -18,6 +18,20 @@ func @chained_islands(%arg0 : i32) -> i32 { return %0 : i32 } +// Check that a function that does not have arguments/results is ignored by +// the pruning pass: this could be a V1 graph imported without feeds/fetches. 
+// CHECK-LABEL: func @empty_islands( +func @empty_islands() { +// CHECK: tf_executor.island + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch + } + return +} + // Check that an unused island that doesn't contribute to the fetch is removed. // CHECK-LABEL: func @dead_island( func @dead_island(%arg0 : i32) -> i32 { @@ -151,37 +165,3 @@ func @control_fetch(%arg0 : i32) { } return } - -// ----- - -// Check that a function that is named "main" and does not have the -// "tf.entry_function" attribute defined is ignored by the pruning pass: this -// could be a V1 graph imported without feed/fetch/target nodes. -// CHECK-LABEL: func @main( -func @main() { -// CHECK: tf_executor.island - tf_executor.graph { - %0 = tf_executor.island { - tf_executor.yield - } - tf_executor.fetch - } - return -} - -// ----- - -// Check that a function that is named "main" and does have the -// "tf.entry_function" attribute defined with no feed/fetch/target nodes is -// pruned. -// CHECK-LABEL: func @main( -func @main() attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = ""}} { -// CHECK-NOT: tf_executor.island - tf_executor.graph { - %0 = tf_executor.island { - tf_executor.yield - } - tf_executor.fetch - } - return -} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index 26c0126932c..859d3ffb23c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -31,18 +30,6 @@ limitations under the License. namespace mlir { namespace tf_executor { -namespace { - -// Checks if a tf_executor.Graph can be pruned. -// For TensorFlow V1.0 compatibility: when importing a graph without providing -// feeds/fetches/targets we should not attempt to prune. The best approximation -// here is to check if the graph is of the "main" function and does not have the -// "tf.entry_function" attribute defined. -bool CanPruneGraph(FuncOp func) { - return func.getName() != "main" || - func.getAttrOfType("tf.entry_function") != nullptr; -} - // Visits an op's operand if it is an output of an Operation in the same // tf_executor.graph. void VisitOpOperand(GraphOp graph, Value operand, @@ -88,8 +75,6 @@ void VisitOp(GraphOp graph, Operation* op, } } -} // namespace - // Prunes unreachable operations of a tf_executor.graph operation. void PruneGraph(GraphOp graph) { // A graph has a single block which forms a DAG: operations that aren't @@ -122,8 +107,15 @@ namespace { // This transformation pass prunes a TF graph eliminating dead-nodes. struct GraphPruning : public PassWrapper { void runOnFunction() override { - if (!CanPruneGraph(getFunction())) return; - getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); + getFunction().walk([](tf_executor::GraphOp graph) { + // For TensorFlow V1.0 compatibility: when importing a graph without + // providing feeds/fetches we should not attempt to prune. The best + // approximation here is to check if the graph does not have any fetched + // values. 
+ if (!graph.GetFetch().getNumOperands()) return; + + PruneGraph(graph); + }); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index a315a5523ca..42ce5c533a2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -2126,27 +2126,28 @@ StatusOr GraphDefImporter::Convert( TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, &control_ret_nodes)); - mlir::Builder b(context); - std::string s; - llvm::raw_string_ostream ss(s); - auto node_name = [&](const OutputTensor& tensor) { - ss << tensor.node->name(); - }; - llvm::interleave(arg_nodes, ss, node_name, ","); - auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(ret_nodes, ss, node_name, ","); - auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(specs.control_outputs, ss, ","); - auto control_outputs = - b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); + if (!arg_nodes.empty() || !ret_nodes.empty() || + !control_ret_nodes.empty()) { + mlir::Builder b(context); + std::string s; + llvm::raw_string_ostream ss(s); + auto node_name = [&](const OutputTensor& tensor) { + ss << tensor.node->name(); + }; + llvm::interleave(arg_nodes, ss, node_name, ","); + auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(ret_nodes, ss, node_name, ","); + auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(specs.control_outputs, ss, ","); + auto control_outputs = + b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); - // Under `graph_as_function` mode, `tf.entry_function` is always set as it - // is assumed feed, fetch, and target nodes are set correctly. 
- attrs.push_back(b.getNamedAttr( - "tf.entry_function", - b.getDictionaryAttr({inputs, outputs, control_outputs}))); + attrs.push_back(b.getNamedAttr( + "tf.entry_function", + b.getDictionaryAttr({inputs, outputs, control_outputs}))); + } } else { // Collects the argument and return nodes by looking up the node names // specified by the user.