From a9e5591dbf3ee499f1df6ba0770600ef76d8d27b Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 20 Jul 2017 19:38:40 -0700 Subject: [PATCH] Preserve identity nodes on the receiver side of a cross device link instead of preserving them when they're on the sender side. PiperOrigin-RevId: 162698209 --- .../grappler/optimizers/graph_rewriter.cc | 11 +++-- .../core/grappler/optimizers/graph_rewriter.h | 5 +++ .../core/grappler/optimizers/model_pruner.cc | 6 ++- .../grappler/optimizers/model_pruner_test.cc | 42 +++++++++++++++++-- 4 files changed, 56 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.cc b/tensorflow/core/grappler/optimizers/graph_rewriter.cc index 2eba753fa4f..5273f11ca03 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.cc +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.cc @@ -69,6 +69,10 @@ bool GraphRewriter::IsConnectedToFunction(const NodeDef& node) const { return function_neighbors_.find(&node) != function_neighbors_.end(); } +bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const { + return cross_device_receivers_.find(&node) != cross_device_receivers_.end(); +} + void GraphRewriter::RecordConnectivity( const NodeDef& node, const std::unordered_set& function_names) { const bool is_function = @@ -94,6 +98,9 @@ void GraphRewriter::RecordConnectivity( function_neighbors_.insert(fanin); } } + if (fanin->device() != node.device()) { + cross_device_receivers_.insert(&node); + } } } @@ -119,9 +126,7 @@ void GraphRewriter::ForwardInputsInternal( continue; } const NodeDef* input_node = itr->second; - if ((input_node->device().empty() || node.device().empty() || - input_node->device() == node.device()) && - nodes_to_delete.find(input_node) != nodes_to_delete.end()) { + if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) { ForwardInputsInternal(*input_node, nodes_to_delete, new_node); } else { *new_node->add_input() = input; diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.h b/tensorflow/core/grappler/optimizers/graph_rewriter.h index 9a61d768b2f..4bdb063d586 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.h +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.h @@ -51,6 +51,10 @@ class GraphRewriter { // fanout (excluding control dependencies) of 'node' is a function. bool IsConnectedToFunction(const NodeDef& node) const; + // Returns true if the node is driven by at least one node placed on another + // device. + bool IsDrivenByAnotherDevice(const NodeDef& node) const; + private: void RecordConnectivity(const NodeDef& node, const std::unordered_set& function_names); @@ -63,6 +67,7 @@ class GraphRewriter { std::unordered_map optimized_nodes_; std::unordered_set control_dependency_drivers_; std::unordered_set function_neighbors_; + std::unordered_set cross_device_receivers_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index 5abada3d6a8..676b41e53dd 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -51,6 +51,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { continue; } + // Don't remove nodes that drive control dependencies. // Don't remove nodes that are driven by control dependencies either since // we can't ensure (yet) that we won't increase the number of control @@ -59,9 +60,12 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // creation of 100 edges). // Don't modify nodes that are connected to functions since that can result // in inlining failures later on. + // Don't prune nodes that are driven by another device since these could be + // used to reduce cross device communication. if (!rewriter.DrivesControlDependency(node) && !rewriter.IsDrivenByControlDependency(node) && - !rewriter.IsConnectedToFunction(node)) { + !rewriter.IsConnectedToFunction(node) && + !rewriter.IsDrivenByAnotherDevice(node)) { nodes_to_delete.insert(&node); } } diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index 3118f3a3602..fdfb3f41cf1 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -94,8 +94,8 @@ TEST_F(ModelPrunerTest, IdentityPruning) { Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output b = ops::AddN(s.WithOpName("b"), {a}); - Output c = ops::Identity(s.WithOpName("c").WithDevice("CPU:0"), b); - Output d = ops::Identity(s.WithOpName("d").WithDevice("GPU:0"), c); + Output c = ops::Identity(s.WithOpName("c"), b); + Output d = ops::Identity(s.WithOpName("d"), c); Output e = ops::AddN(s.WithOpName("e"), {d}); GrapplerItem item; @@ -119,9 +119,11 @@ TEST_F(ModelPrunerTest, IdentityPruning) { EXPECT_EQ(NodeName(e.name()), new_e.name()); EXPECT_EQ(1, new_e.input_size()); - EXPECT_EQ(NodeName(c.name()), new_e.input(0)); + EXPECT_EQ(NodeName(b.name()), new_e.input(0)); EXPECT_EQ(1, new_d.input_size()); - EXPECT_EQ(NodeName(c.name()), new_d.input(0)); + EXPECT_EQ(NodeName(b.name()), new_d.input(0)); + EXPECT_EQ(1, new_c.input_size()); + EXPECT_EQ(NodeName(b.name()), new_c.input(0)); } TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { @@ -235,6 +237,38 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) { EXPECT_EQ(NodeName(c.name()), new_c.name()); } +TEST_F(ModelPrunerTest, PruningPerservesCrossDeviceIdentity) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output c = ops::Const(s.WithOpName("c").WithDevice("/cpu:0"), 0.0f, {10, 10}); + + // Node i1 should be preserved. + Output i1 = ops::Identity(s.WithOpName("i1").WithDevice("/gpu:0"), c); + Output a1 = ops::AddN(s.WithOpName("a1").WithDevice("/gpu:0"), {i1}); + Output a2 = ops::AddN(s.WithOpName("a2").WithDevice("/gpu:0"), {i1}); + + // Node i2 should be pruned since it resides on the sender's device. + Output i2 = ops::Identity(s.WithOpName("i2").WithDevice("/cpu:0"), c); + Output a3 = ops::AddN(s.WithOpName("a3").WithDevice("/gpu:0"), {i2}); + Output a4 = ops::AddN(s.WithOpName("a4").WithDevice("/gpu:0"), {i2}); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"a1", "a2", "a3", "a4"}; + + ModelPruner pruner; + GraphDef output; + Status status = pruner.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + for (const auto& node : output.node()) { + if (node.name() == "a1" || node.name() == "a2") { + EXPECT_EQ("i1", node.input(0)); + } else if (node.name() == "a3" || node.name() == "a4") { + EXPECT_EQ("c", node.input(0)); + } + } +} + } // namespace } // namespace grappler } // namespace tensorflow