Revert "Update importer to always populate "tf.entry_function" ..."
PiperOrigin-RevId: 338064350 Change-Id: I972277dd19b061c0d83a1533005f46de0720bb05
This commit is contained in:
parent
d14a44fb49
commit
5fe5a49092
@ -1,4 +1,4 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=assign_variable | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "arg0"
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -split-input-file -tf-executor-graph-pruning | FileCheck %s
|
||||
// RUN: tf-opt %s -tf-executor-graph-pruning | FileCheck %s
|
||||
|
||||
// Two islands chained by data-flow contributing to the graph return are
|
||||
// preserved.
|
||||
@ -18,6 +18,20 @@ func @chained_islands(%arg0 : i32) -> i32 {
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
// Check that a function that does not have arguments/results is ignored by
|
||||
// thep pruning pass: this could be a V1 graph imported without feeds/fetches.
|
||||
// CHECK-LABEL: func @empty_islands(
|
||||
func @empty_islands() {
|
||||
// CHECK: tf_executor.island
|
||||
tf_executor.graph {
|
||||
%0 = tf_executor.island {
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check that an unused island that doesn't contribute to the fetch is removed.
|
||||
// CHECK-LABEL: func @dead_island(
|
||||
func @dead_island(%arg0 : i32) -> i32 {
|
||||
@ -151,37 +165,3 @@ func @control_fetch(%arg0 : i32) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that a function that is named "main" and does not have the
|
||||
// "tf.entry_function" attribute defined is ignored by the pruning pass: this
|
||||
// could be a V1 graph imported without feed/fetch/target nodes.
|
||||
// CHECK-LABEL: func @main(
|
||||
func @main() {
|
||||
// CHECK: tf_executor.island
|
||||
tf_executor.graph {
|
||||
%0 = tf_executor.island {
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that a function that is named "main" and does have the
|
||||
// "tf.entry_function" attribute defined with no feed/fetch/target nodes is
|
||||
// pruned.
|
||||
// CHECK-LABEL: func @main(
|
||||
func @main() attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = ""}} {
|
||||
// CHECK-NOT: tf_executor.island
|
||||
tf_executor.graph {
|
||||
%0 = tf_executor.island {
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
@ -31,18 +30,6 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace tf_executor {
|
||||
|
||||
namespace {
|
||||
|
||||
// Checks if a tf_executor.Graph can be pruned.
|
||||
// For TensorFlow V1.0 compatibility: when importing a graph without providing
|
||||
// feeds/fetches/targets we should not attempt to prune. The best approximation
|
||||
// here is to check if the graph is of the "main" function and does not have the
|
||||
// "tf.entry_function" attribute defined.
|
||||
bool CanPruneGraph(FuncOp func) {
|
||||
return func.getName() != "main" ||
|
||||
func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
|
||||
}
|
||||
|
||||
// Visits an op's operand if it is an output of an Operation in the same
|
||||
// tf_executor.graph.
|
||||
void VisitOpOperand(GraphOp graph, Value operand,
|
||||
@ -88,8 +75,6 @@ void VisitOp(GraphOp graph, Operation* op,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Prunes unreachable operations of a tf_executor.graph operation.
|
||||
void PruneGraph(GraphOp graph) {
|
||||
// A graph has a single block which forms a DAG: operations that aren't
|
||||
@ -122,8 +107,15 @@ namespace {
|
||||
// This transformation pass prunes a TF graph eliminating dead-nodes.
|
||||
struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
if (!CanPruneGraph(getFunction())) return;
|
||||
getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
|
||||
getFunction().walk([](tf_executor::GraphOp graph) {
|
||||
// For TensorFlow V1.0 compatibility: when importing a graph without
|
||||
// providing feeds/fetches we should not attempt to prune. The best
|
||||
// approximation here is to check if the graph does not have any fetched
|
||||
// values.
|
||||
if (!graph.GetFetch().getNumOperands()) return;
|
||||
|
||||
PruneGraph(graph);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2126,27 +2126,28 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
|
||||
&control_ret_nodes));
|
||||
|
||||
mlir::Builder b(context);
|
||||
std::string s;
|
||||
llvm::raw_string_ostream ss(s);
|
||||
auto node_name = [&](const OutputTensor& tensor) {
|
||||
ss << tensor.node->name();
|
||||
};
|
||||
llvm::interleave(arg_nodes, ss, node_name, ",");
|
||||
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
|
||||
s.clear();
|
||||
llvm::interleave(ret_nodes, ss, node_name, ",");
|
||||
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
|
||||
s.clear();
|
||||
llvm::interleave(specs.control_outputs, ss, ",");
|
||||
auto control_outputs =
|
||||
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
|
||||
if (!arg_nodes.empty() || !ret_nodes.empty() ||
|
||||
!control_ret_nodes.empty()) {
|
||||
mlir::Builder b(context);
|
||||
std::string s;
|
||||
llvm::raw_string_ostream ss(s);
|
||||
auto node_name = [&](const OutputTensor& tensor) {
|
||||
ss << tensor.node->name();
|
||||
};
|
||||
llvm::interleave(arg_nodes, ss, node_name, ",");
|
||||
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
|
||||
s.clear();
|
||||
llvm::interleave(ret_nodes, ss, node_name, ",");
|
||||
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
|
||||
s.clear();
|
||||
llvm::interleave(specs.control_outputs, ss, ",");
|
||||
auto control_outputs =
|
||||
b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
|
||||
|
||||
// Under `graph_as_function` mode, `tf.entry_function` is always set as it
|
||||
// is assumed feed, fetch, and target nodes are set correctly.
|
||||
attrs.push_back(b.getNamedAttr(
|
||||
"tf.entry_function",
|
||||
b.getDictionaryAttr({inputs, outputs, control_outputs})));
|
||||
attrs.push_back(b.getNamedAttr(
|
||||
"tf.entry_function",
|
||||
b.getDictionaryAttr({inputs, outputs, control_outputs})));
|
||||
}
|
||||
} else {
|
||||
// Collects the argument and return nodes by looking up the node names
|
||||
// specified by the user.
|
||||
|
Loading…
x
Reference in New Issue
Block a user