Revert "Update importer to always populate "tf.entry_function" ..."

PiperOrigin-RevId: 338064350
Change-Id: I972277dd19b061c0d83a1533005f46de0720bb05
This commit is contained in:
Mihai Maruseac 2020-10-20 08:30:48 -07:00 committed by TensorFlower Gardener
parent d14a44fb49
commit 5fe5a49092
4 changed files with 46 additions and 73 deletions

View File

@ -1,4 +1,4 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=assign_variable | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s
node {
name: "arg0"

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -split-input-file -tf-executor-graph-pruning | FileCheck %s
// RUN: tf-opt %s -tf-executor-graph-pruning | FileCheck %s
// Two islands chained by data-flow contributing to the graph return are
// preserved.
@ -18,6 +18,20 @@ func @chained_islands(%arg0 : i32) -> i32 {
return %0 : i32
}
// Check that a function that does not have arguments/results is ignored by
// thep pruning pass: this could be a V1 graph imported without feeds/fetches.
// CHECK-LABEL: func @empty_islands(
func @empty_islands() {
// CHECK: tf_executor.island
tf_executor.graph {
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}
// Check that an unused island that doesn't contribute to the fetch is removed.
// CHECK-LABEL: func @dead_island(
func @dead_island(%arg0 : i32) -> i32 {
@ -151,37 +165,3 @@ func @control_fetch(%arg0 : i32) {
}
return
}
// -----
// Check that a function that is named "main" and does not have the
// "tf.entry_function" attribute defined is ignored by the pruning pass: this
// could be a V1 graph imported without feed/fetch/target nodes.
// CHECK-LABEL: func @main(
func @main() {
// CHECK: tf_executor.island
tf_executor.graph {
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}
// -----
// Check that a function that is named "main" and does have the
// "tf.entry_function" attribute defined with no feed/fetch/target nodes is
// pruned.
// CHECK-LABEL: func @main(
func @main() attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = ""}} {
// CHECK-NOT: tf_executor.island
tf_executor.graph {
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
@ -31,18 +30,6 @@ limitations under the License.
namespace mlir {
namespace tf_executor {
namespace {
// Checks if a tf_executor.Graph can be pruned.
// For TensorFlow V1.0 compatibility: when importing a graph without providing
// feeds/fetches/targets we should not attempt to prune. The best approximation
// here is to check if the graph is of the "main" function and does not have the
// "tf.entry_function" attribute defined.
bool CanPruneGraph(FuncOp func) {
return func.getName() != "main" ||
func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
}
// Visits an op's operand if it is an output of an Operation in the same
// tf_executor.graph.
void VisitOpOperand(GraphOp graph, Value operand,
@ -88,8 +75,6 @@ void VisitOp(GraphOp graph, Operation* op,
}
}
} // namespace
// Prunes unreachable operations of a tf_executor.graph operation.
void PruneGraph(GraphOp graph) {
// A graph has a single block which forms a DAG: operations that aren't
@ -122,8 +107,15 @@ namespace {
// This transformation pass prunes a TF graph eliminating dead-nodes.
struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> {
void runOnFunction() override {
if (!CanPruneGraph(getFunction())) return;
getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
getFunction().walk([](tf_executor::GraphOp graph) {
// For TensorFlow V1.0 compatibility: when importing a graph without
// providing feeds/fetches we should not attempt to prune. The best
// approximation here is to check if the graph does not have any fetched
// values.
if (!graph.GetFetch().getNumOperands()) return;
PruneGraph(graph);
});
}
};

View File

@ -2126,27 +2126,28 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
&control_ret_nodes));
mlir::Builder b(context);
std::string s;
llvm::raw_string_ostream ss(s);
auto node_name = [&](const OutputTensor& tensor) {
ss << tensor.node->name();
};
llvm::interleave(arg_nodes, ss, node_name, ",");
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
s.clear();
llvm::interleave(ret_nodes, ss, node_name, ",");
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
s.clear();
llvm::interleave(specs.control_outputs, ss, ",");
auto control_outputs =
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
if (!arg_nodes.empty() || !ret_nodes.empty() ||
!control_ret_nodes.empty()) {
mlir::Builder b(context);
std::string s;
llvm::raw_string_ostream ss(s);
auto node_name = [&](const OutputTensor& tensor) {
ss << tensor.node->name();
};
llvm::interleave(arg_nodes, ss, node_name, ",");
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
s.clear();
llvm::interleave(ret_nodes, ss, node_name, ",");
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
s.clear();
llvm::interleave(specs.control_outputs, ss, ",");
auto control_outputs =
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
// Under `graph_as_function` mode, `tf.entry_function` is always set as it
// is assumed feed, fetch, and target nodes are set correctly.
attrs.push_back(b.getNamedAttr(
"tf.entry_function",
b.getDictionaryAttr({inputs, outputs, control_outputs})));
attrs.push_back(b.getNamedAttr(
"tf.entry_function",
b.getDictionaryAttr({inputs, outputs, control_outputs})));
}
} else {
// Collects the argument and return nodes by looking up the node names
// specified by the user.