Prune redundant control inputs early in model_pruner, since they may prevent deletion of trivial nodes.
Prune NoOp nodes with empty fanout. PiperOrigin-RevId: 312514074 Change-Id: I22cb76f5b9b152fc51ce34918d28a81f929ffa38
This commit is contained in:
parent
9ef6f66ce1
commit
502e75c139
tensorflow/core/grappler/optimizers
@ -33,6 +33,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
bool IsTrivialIdentity(const NodeDef& node, const GraphView& graph_view) {
|
||||
for (const auto input :
|
||||
@ -103,7 +104,9 @@ bool IsOutputPortRefValue(const NodeDef& node, int port_id,
|
||||
bool CanRemoveNode(const NodeDef& node, const GraphView& graph_view,
|
||||
const absl::flat_hash_set<string>& function_names,
|
||||
const OpRegistryInterface& op_registry) {
|
||||
if (IsNoOp(node) && node.input().empty()) {
|
||||
if (IsNoOp(node) &&
|
||||
(node.input().empty() ||
|
||||
graph_view.NumFanouts(node, /*include_controlled_nodes=*/true) == 0)) {
|
||||
return true;
|
||||
}
|
||||
if (IsConstant(node) && node.input().empty() &&
|
||||
@ -412,6 +415,8 @@ Status SplitIdentityNInputs(GraphDef* graph,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
|
||||
@ -453,13 +458,18 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
// Check if we can further prune the graph, by removing the trivial ops.
|
||||
absl::flat_hash_set<const NodeDef*> nodes_to_delete;
|
||||
for (const auto& node : pruned_graph->node()) {
|
||||
if (!IsTrivialOp(node, graph_view)) {
|
||||
for (int i = 0; i < pruned_graph->node_size(); ++i) {
|
||||
NodeDef* node = pruned_graph->mutable_node(i);
|
||||
// Remove redundant control inputs, since they may prevent pruning below.
|
||||
DedupControlInputs(node);
|
||||
|
||||
if (!IsTrivialOp(*node, graph_view)) {
|
||||
VLOG(3) << node->name() << " is not trivial.";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Don't remove nodes that must be preserved.
|
||||
if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
|
||||
if (nodes_to_preserve.find(node->name()) != nodes_to_preserve.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -477,8 +487,10 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// converting references to non-references. It is important to preserve
|
||||
// these non-references since the partitioner will avoid sending
|
||||
// non-references across partitions more than once.
|
||||
if (CanRemoveNode(node, graph_view, function_names, *op_registry)) {
|
||||
nodes_to_delete.insert(&node);
|
||||
if (CanRemoveNode(*node, graph_view, function_names, *op_registry)) {
|
||||
nodes_to_delete.insert(node);
|
||||
} else {
|
||||
VLOG(3) << node->name() << " cannot be removed";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -100,12 +100,13 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
|
||||
|
||||
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
|
||||
Output b = ops::Sqrt(s.WithOpName("b"), {a});
|
||||
Output c = ops::Identity(s.WithOpName("c"), b);
|
||||
Output c = ops::Identity(s.WithOpName("c").WithControlDependencies(b), b);
|
||||
Output d = ops::Identity(s.WithOpName("d"), c);
|
||||
Output e = ops::Sqrt(s.WithOpName("e"), {d});
|
||||
|
||||
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
||||
}
|
||||
item.fetch.push_back("e");
|
||||
|
||||
ModelPruner pruner;
|
||||
GraphDef output;
|
||||
@ -117,8 +118,6 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
|
||||
|
||||
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
|
||||
Output b = ops::Sqrt(s.WithOpName("b"), {a});
|
||||
Output c = ops::Identity(s.WithOpName("c"), b);
|
||||
Output d = ops::Identity(s.WithOpName("d"), b);
|
||||
Output e = ops::Sqrt(s.WithOpName("e"), {b});
|
||||
|
||||
TF_ASSERT_OK(s.ToGraphDef(&expected));
|
||||
@ -126,10 +125,9 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
|
||||
|
||||
CompareGraphs(expected, output);
|
||||
|
||||
std::vector<string> fetch = {"e"};
|
||||
auto actual_tensors = EvaluateNodes(output, fetch);
|
||||
auto actual_tensors = EvaluateNodes(output, item.fetch);
|
||||
ASSERT_EQ(actual_tensors.size(), 1);
|
||||
auto expected_tensors = EvaluateNodes(item.graph, fetch);
|
||||
auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
|
||||
ASSERT_EQ(expected_tensors.size(), 1);
|
||||
test::ExpectTensorEqual<float>(actual_tensors[0], expected_tensors[0]);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user