Revert "Update importer to always populate "tf.entry_function" ..."
PiperOrigin-RevId: 338064350
Change-Id: I972277dd19b061c0d83a1533005f46de0720bb05
parent d14a44fb49
commit 5fe5a49092
@@ -1,4 +1,4 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=assign_variable | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s
+# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s
 
 node {
   name: "arg0"
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -split-input-file -tf-executor-graph-pruning | FileCheck %s
+// RUN: tf-opt %s -tf-executor-graph-pruning | FileCheck %s
 
 // Two islands chained by data-flow contributing to the graph return are
 // preserved.
@@ -18,6 +18,20 @@ func @chained_islands(%arg0 : i32) -> i32 {
   return %0 : i32
 }
 
+// Check that a function that does not have arguments/results is ignored by
+// the pruning pass: this could be a V1 graph imported without feeds/fetches.
+// CHECK-LABEL: func @empty_islands(
+func @empty_islands() {
+  // CHECK: tf_executor.island
+  tf_executor.graph {
+    %0 = tf_executor.island {
+      tf_executor.yield
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
 // Check that an unused island that doesn't contribute to the fetch is removed.
 // CHECK-LABEL: func @dead_island(
 func @dead_island(%arg0 : i32) -> i32 {
@@ -151,37 +165,3 @@ func @control_fetch(%arg0 : i32) {
   }
   return
 }
-
-// -----
-
-// Check that a function that is named "main" and does not have the
-// "tf.entry_function" attribute defined is ignored by the pruning pass: this
-// could be a V1 graph imported without feed/fetch/target nodes.
-// CHECK-LABEL: func @main(
-func @main() {
-  // CHECK: tf_executor.island
-  tf_executor.graph {
-    %0 = tf_executor.island {
-      tf_executor.yield
-    }
-    tf_executor.fetch
-  }
-  return
-}
-
-// -----
-
-// Check that a function that is named "main" and does have the
-// "tf.entry_function" attribute defined with no feed/fetch/target nodes is
-// pruned.
-// CHECK-LABEL: func @main(
-func @main() attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = ""}} {
-  // CHECK-NOT: tf_executor.island
-  tf_executor.graph {
-    %0 = tf_executor.island {
-      tf_executor.yield
-    }
-    tf_executor.fetch
-  }
-  return
-}
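The "tf.entry_function" attribute exercised by the removed @main test above is a dictionary attribute whose inputs, outputs, and control_outputs entries are comma-separated node-name strings. For reference, a minimal standalone sketch of assembling such an attribute with the MLIR C++ builder API follows; the helper name MakeEntryFunctionAttr is hypothetical and not part of this change.

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

// Illustration only: builds an attribute shaped like
// {control_outputs = "...", inputs = "...", outputs = "..."}.
mlir::DictionaryAttr MakeEntryFunctionAttr(mlir::MLIRContext* context,
                                           llvm::StringRef inputs,
                                           llvm::StringRef outputs,
                                           llvm::StringRef control_outputs) {
  mlir::Builder b(context);
  return b.getDictionaryAttr(
      {b.getNamedAttr("inputs", b.getStringAttr(inputs)),
       b.getNamedAttr("outputs", b.getStringAttr(outputs)),
       b.getNamedAttr("control_outputs", b.getStringAttr(control_outputs))});
}

An empty attribute like the one in the test, {control_outputs = "", inputs = "", outputs = ""}, simply corresponds to passing empty strings.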
@@ -18,7 +18,6 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -31,18 +30,6 @@ limitations under the License.
 namespace mlir {
 namespace tf_executor {
 
-namespace {
-
-// Checks if a tf_executor.Graph can be pruned.
-// For TensorFlow V1.0 compatibility: when importing a graph without providing
-// feeds/fetches/targets we should not attempt to prune. The best approximation
-// here is to check if the graph is of the "main" function and does not have the
-// "tf.entry_function" attribute defined.
-bool CanPruneGraph(FuncOp func) {
-  return func.getName() != "main" ||
-         func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
-}
-
 // Visits an op's operand if it is an output of an Operation in the same
 // tf_executor.graph.
 void VisitOpOperand(GraphOp graph, Value operand,
@@ -88,8 +75,6 @@ void VisitOp(GraphOp graph, Operation* op,
   }
 }
 
-}  // namespace
-
 // Prunes unreachable operations of a tf_executor.graph operation.
 void PruneGraph(GraphOp graph) {
   // A graph has a single block which forms a DAG: operations that aren't
@@ -122,8 +107,15 @@ namespace {
 // This transformation pass prunes a TF graph eliminating dead-nodes.
 struct GraphPruning : public PassWrapper<GraphPruning, FunctionPass> {
   void runOnFunction() override {
-    if (!CanPruneGraph(getFunction())) return;
-    getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
+    getFunction().walk([](tf_executor::GraphOp graph) {
+      // For TensorFlow V1.0 compatibility: when importing a graph without
+      // providing feeds/fetches we should not attempt to prune. The best
+      // approximation here is to check if the graph does not have any fetched
+      // values.
+      if (!graph.GetFetch().getNumOperands()) return;
+
+      PruneGraph(graph);
+    });
   }
 };
 
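In the restored runOnFunction above, the early return inside the walk callback only skips pruning for that particular tf_executor.graph; the walk itself continues over any remaining graphs. A pass written with PassWrapper<..., FunctionPass> like this one is normally exposed to an opt-style tool through a static registration; the sketch below is a hedged illustration under the MLIR pass API of this commit's era, with a made-up pass name (the real flag used by the tests is -tf-executor-graph-pruning).

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"

namespace {
// Stand-in pass with the same shape as GraphPruning (illustration only).
struct ExamplePrunePass
    : public mlir::PassWrapper<ExamplePrunePass, mlir::FunctionPass> {
  void runOnFunction() override {
    // A real pass would walk the function and prune unreachable ops here.
  }
};
}  // namespace

// Registers the pass so opt-style tools can run it by flag name.
static mlir::PassRegistration<ExamplePrunePass> registration(
    "example-prune", "Illustrative graph-pruning pass registration");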
@@ -2126,27 +2126,28 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
     TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs,
                                                         &control_ret_nodes));
 
-    mlir::Builder b(context);
-    std::string s;
-    llvm::raw_string_ostream ss(s);
-    auto node_name = [&](const OutputTensor& tensor) {
-      ss << tensor.node->name();
-    };
-    llvm::interleave(arg_nodes, ss, node_name, ",");
-    auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
-    s.clear();
-    llvm::interleave(ret_nodes, ss, node_name, ",");
-    auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
-    s.clear();
-    llvm::interleave(specs.control_outputs, ss, ",");
-    auto control_outputs =
-        b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
-
-    // Under `graph_as_function` mode, `tf.entry_function` is always set as it
-    // is assumed feed, fetch, and target nodes are set correctly.
-    attrs.push_back(b.getNamedAttr(
-        "tf.entry_function",
-        b.getDictionaryAttr({inputs, outputs, control_outputs})));
+    if (!arg_nodes.empty() || !ret_nodes.empty() ||
+        !control_ret_nodes.empty()) {
+      mlir::Builder b(context);
+      std::string s;
+      llvm::raw_string_ostream ss(s);
+      auto node_name = [&](const OutputTensor& tensor) {
+        ss << tensor.node->name();
+      };
+      llvm::interleave(arg_nodes, ss, node_name, ",");
+      auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
+      s.clear();
+      llvm::interleave(ret_nodes, ss, node_name, ",");
+      auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
+      s.clear();
+      llvm::interleave(specs.control_outputs, ss, ",");
+      auto control_outputs =
+          b.getNamedAttr("control_outputs", b.getStringAttr(ss.str()));
+
+      attrs.push_back(b.getNamedAttr(
+          "tf.entry_function",
+          b.getDictionaryAttr({inputs, outputs, control_outputs})));
+    }
   } else {
     // Collects the argument and return nodes by looking up the node names
     // specified by the user.
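The restored importer code above builds the inputs, outputs, and control_outputs strings with llvm::interleave, which streams each element and inserts the separator between elements. A small self-contained sketch of that joining step, with made-up node names standing in for ret_nodes:

#include <string>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // Hypothetical return-node names standing in for ret_nodes.
  std::vector<std::string> ret_nodes = {"out0", "out1"};
  std::string s;
  llvm::raw_string_ostream ss(s);
  // Streams each name, inserting "," between elements: "out0,out1".
  llvm::interleave(
      ret_nodes, ss, [&](const std::string& name) { ss << name; }, ",");
  ss.flush();
  llvm::outs() << s << "\n";
  return 0;
}

For element types that already print to a raw_ostream, the simpler overload llvm::interleave(ret_nodes, ss, ",") produces the same result.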