Revert "Update importer to always populate "tf.entry_function" ..."

PiperOrigin-RevId: 338064350
Change-Id: I972277dd19b061c0d83a1533005f46de0720bb05
This commit is contained in:
Mihai Maruseac 2020-10-20 08:30:48 -07:00 committed by TensorFlower Gardener
parent d14a44fb49
commit 5fe5a49092
4 changed files with 46 additions and 73 deletions

View File

@ -1,4 +1,4 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=assign_variable | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s # RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s
node { node {
name: "arg0" name: "arg0"

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -split-input-file -tf-executor-graph-pruning | FileCheck %s // RUN: tf-opt %s -tf-executor-graph-pruning | FileCheck %s
// Two islands chained by data-flow contributing to the graph return are // Two islands chained by data-flow contributing to the graph return are
// preserved. // preserved.
@ -18,6 +18,20 @@ func @chained_islands(%arg0 : i32) -> i32 {
return %0 : i32 return %0 : i32
} }
// Check that a function that does not have arguments/results is ignored by
// thep pruning pass: this could be a V1 graph imported without feeds/fetches.
// CHECK-LABEL: func @empty_islands(
func @empty_islands() {
// CHECK: tf_executor.island
tf_executor.graph {
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}
// Check that an unused island that doesn't contribute to the fetch is removed. // Check that an unused island that doesn't contribute to the fetch is removed.
// CHECK-LABEL: func @dead_island( // CHECK-LABEL: func @dead_island(
func @dead_island(%arg0 : i32) -> i32 { func @dead_island(%arg0 : i32) -> i32 {
@ -151,37 +165,3 @@ func @control_fetch(%arg0 : i32) {
} }
return return
} }
// -----
// Check that a function that is named "main" and does not have the
// "tf.entry_function" attribute defined is ignored by the pruning pass: this
// could be a V1 graph imported without feed/fetch/target nodes.
// CHECK-LABEL: func @main(
func @main() {
// CHECK: tf_executor.island
tf_executor.graph {
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}
// -----
// Check that a function that is named "main" and does have the
// "tf.entry_function" attribute defined with no feed/fetch/target nodes is
// pruned.
// CHECK-LABEL: func @main(
func @main() attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = ""}} {
// CHECK-NOT: tf_executor.island
tf_executor.graph {
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h" #include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project
@ -31,18 +30,6 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace tf_executor { namespace tf_executor {
namespace {
// Checks if a tf_executor.Graph can be pruned.
// For TensorFlow V1.0 compatibility: when importing a graph without providing
// feeds/fetches/targets we should not attempt to prune. The best approximation
// here is to check if the graph is of the "main" function and does not have the
// "tf.entry_function" attribute defined.
bool CanPruneGraph(FuncOp func) {
return func.getName() != "main" ||
func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
}
// Visits an op's operand if it is an output of an Operation in the same // Visits an op's operand if it is an output of an Operation in the same
// tf_executor.graph. // tf_executor.graph.
void VisitOpOperand(GraphOp graph, Value operand, void VisitOpOperand(GraphOp graph, Value operand,
@ -88,8 +75,6 @@ void VisitOp(GraphOp graph, Operation* op,
} }
} }
} // namespace
// Prunes unreachable operations of a tf_executor.graph operation. // Prunes unreachable operations of a tf_executor.graph operation.
void PruneGraph(GraphOp graph) { void PruneGraph(GraphOp graph) {
// A graph has a single block which forms a DAG: operations that aren't // A graph has a single block which forms a DAG: operations that aren't
@ -122,8 +107,15 @@ namespace {
// This transformation pass prunes a TF graph eliminating dead-nodes. // This transformation pass prunes a TF graph eliminating dead-nodes.
struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> { struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
if (!CanPruneGraph(getFunction())) return; getFunction().walk([](tf_executor::GraphOp graph) {
getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); // For TensorFlow V1.0 compatibility: when importing a graph without
// providing feeds/fetches we should not attempt to prune. The best
// approximation here is to check if the graph does not have any fetched
// values.
if (!graph.GetFetch().getNumOperands()) return;
PruneGraph(graph);
});
} }
}; };

View File

@ -2126,27 +2126,28 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
&control_ret_nodes)); &control_ret_nodes));
mlir::Builder b(context); if (!arg_nodes.empty() || !ret_nodes.empty() ||
std::string s; !control_ret_nodes.empty()) {
llvm::raw_string_ostream ss(s); mlir::Builder b(context);
auto node_name = [&](const OutputTensor& tensor) { std::string s;
ss << tensor.node->name(); llvm::raw_string_ostream ss(s);
}; auto node_name = [&](const OutputTensor& tensor) {
llvm::interleave(arg_nodes, ss, node_name, ","); ss << tensor.node->name();
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); };
s.clear(); llvm::interleave(arg_nodes, ss, node_name, ",");
llvm::interleave(ret_nodes, ss, node_name, ","); auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); s.clear();
s.clear(); llvm::interleave(ret_nodes, ss, node_name, ",");
llvm::interleave(specs.control_outputs, ss, ","); auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
auto control_outputs = s.clear();
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); llvm::interleave(specs.control_outputs, ss, ",");
auto control_outputs =
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
// Under `graph_as_function` mode, `tf.entry_function` is always set as it attrs.push_back(b.getNamedAttr(
// is assumed feed, fetch, and target nodes are set correctly. "tf.entry_function",
attrs.push_back(b.getNamedAttr( b.getDictionaryAttr({inputs, outputs, control_outputs})));
"tf.entry_function", }
b.getDictionaryAttr({inputs, outputs, control_outputs})));
} else { } else {
// Collects the argument and return nodes by looking up the node names // Collects the argument and return nodes by looking up the node names
// specified by the user. // specified by the user.