Prune redundant control inputs (via DedupControlInputs) early in model_pruner, before the trivial-op check, since they may prevent deletion of trivial nodes.

Also prune NoOp nodes with empty fanout (no consumers, counting control-dependent nodes).

PiperOrigin-RevId: 312514074
Change-Id: I22cb76f5b9b152fc51ce34918d28a81f929ffa38
This commit is contained in:
A. Unique TensorFlower 2020-05-20 10:55:43 -07:00 committed by TensorFlower Gardener
parent 9ef6f66ce1
commit 502e75c139
2 changed files with 22 additions and 12 deletions
tensorflow/core/grappler/optimizers

View File

@ -33,6 +33,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
bool IsTrivialIdentity(const NodeDef& node, const GraphView& graph_view) {
for (const auto input :
@ -103,7 +104,9 @@ bool IsOutputPortRefValue(const NodeDef& node, int port_id,
bool CanRemoveNode(const NodeDef& node, const GraphView& graph_view,
const absl::flat_hash_set<string>& function_names,
const OpRegistryInterface& op_registry) {
if (IsNoOp(node) && node.input().empty()) {
if (IsNoOp(node) &&
(node.input().empty() ||
graph_view.NumFanouts(node, /*include_controlled_nodes=*/true) == 0)) {
return true;
}
if (IsConstant(node) && node.input().empty() &&
@ -412,6 +415,8 @@ Status SplitIdentityNInputs(GraphDef* graph,
return Status::OK();
}
} // namespace
Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
@ -453,13 +458,18 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
// Check if we can further prune the graph, by removing the trivial ops.
absl::flat_hash_set<const NodeDef*> nodes_to_delete;
for (const auto& node : pruned_graph->node()) {
if (!IsTrivialOp(node, graph_view)) {
for (int i = 0; i < pruned_graph->node_size(); ++i) {
NodeDef* node = pruned_graph->mutable_node(i);
// Remove redundant control inputs, since they may prevent pruning below.
DedupControlInputs(node);
if (!IsTrivialOp(*node, graph_view)) {
VLOG(3) << node->name() << " is not trivial.";
continue;
}
// Don't remove nodes that must be preserved.
if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
if (nodes_to_preserve.find(node->name()) != nodes_to_preserve.end()) {
continue;
}
@ -477,8 +487,10 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
// converting references to non-references. It is important to preserve
// these non-references since the partitioner will avoid sending
// non-references across partitions more than once.
if (CanRemoveNode(node, graph_view, function_names, *op_registry)) {
nodes_to_delete.insert(&node);
if (CanRemoveNode(*node, graph_view, function_names, *op_registry)) {
nodes_to_delete.insert(node);
} else {
VLOG(3) << node->name() << " cannot be removed";
}
}

View File

@ -100,12 +100,13 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
Output c = ops::Identity(s.WithOpName("c").WithControlDependencies(b), b);
Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::Sqrt(s.WithOpName("e"), {d});
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
}
item.fetch.push_back("e");
ModelPruner pruner;
GraphDef output;
@ -117,8 +118,6 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d"), b);
Output e = ops::Sqrt(s.WithOpName("e"), {b});
TF_ASSERT_OK(s.ToGraphDef(&expected));
@ -126,10 +125,9 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
CompareGraphs(expected, output);
std::vector<string> fetch = {"e"};
auto actual_tensors = EvaluateNodes(output, fetch);
auto actual_tensors = EvaluateNodes(output, item.fetch);
ASSERT_EQ(actual_tensors.size(), 1);
auto expected_tensors = EvaluateNodes(item.graph, fetch);
auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
ASSERT_EQ(expected_tensors.size(), 1);
test::ExpectTensorEqual<float>(actual_tensors[0], expected_tensors[0]);
}