Preserve identity nodes on the receiver side of a cross device link instead of

preserving them when they're on the sender side.

PiperOrigin-RevId: 162698209
This commit is contained in:
Benoit Steiner 2017-07-20 19:38:40 -07:00 committed by TensorFlower Gardener
parent 386f4aef0d
commit a9e5591dbf
4 changed files with 56 additions and 8 deletions

View File

@ -69,6 +69,10 @@ bool GraphRewriter::IsConnectedToFunction(const NodeDef& node) const {
return function_neighbors_.find(&node) != function_neighbors_.end(); return function_neighbors_.find(&node) != function_neighbors_.end();
} }
bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const {
return cross_device_receivers_.find(&node) != cross_device_receivers_.end();
}
void GraphRewriter::RecordConnectivity( void GraphRewriter::RecordConnectivity(
const NodeDef& node, const std::unordered_set<string>& function_names) { const NodeDef& node, const std::unordered_set<string>& function_names) {
const bool is_function = const bool is_function =
@ -94,6 +98,9 @@ void GraphRewriter::RecordConnectivity(
function_neighbors_.insert(fanin); function_neighbors_.insert(fanin);
} }
} }
if (fanin->device() != node.device()) {
cross_device_receivers_.insert(&node);
}
} }
} }
@ -119,9 +126,7 @@ void GraphRewriter::ForwardInputsInternal(
continue; continue;
} }
const NodeDef* input_node = itr->second; const NodeDef* input_node = itr->second;
if ((input_node->device().empty() || node.device().empty() || if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
input_node->device() == node.device()) &&
nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
ForwardInputsInternal(*input_node, nodes_to_delete, new_node); ForwardInputsInternal(*input_node, nodes_to_delete, new_node);
} else { } else {
*new_node->add_input() = input; *new_node->add_input() = input;

View File

@ -51,6 +51,10 @@ class GraphRewriter {
// fanout (excluding control dependencies) of 'node' is a function. // fanout (excluding control dependencies) of 'node' is a function.
bool IsConnectedToFunction(const NodeDef& node) const; bool IsConnectedToFunction(const NodeDef& node) const;
// Returns true if the node is driven by at least one node placed on another
// device.
bool IsDrivenByAnotherDevice(const NodeDef& node) const;
private: private:
void RecordConnectivity(const NodeDef& node, void RecordConnectivity(const NodeDef& node,
const std::unordered_set<string>& function_names); const std::unordered_set<string>& function_names);
@ -63,6 +67,7 @@ class GraphRewriter {
std::unordered_map<string, const NodeDef*> optimized_nodes_; std::unordered_map<string, const NodeDef*> optimized_nodes_;
std::unordered_set<const NodeDef*> control_dependency_drivers_; std::unordered_set<const NodeDef*> control_dependency_drivers_;
std::unordered_set<const NodeDef*> function_neighbors_; std::unordered_set<const NodeDef*> function_neighbors_;
std::unordered_set<const NodeDef*> cross_device_receivers_;
}; };
} // end namespace grappler } // end namespace grappler

View File

@ -51,6 +51,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
continue; continue;
} }
// Don't remove nodes that drive control dependencies. // Don't remove nodes that drive control dependencies.
// Don't remove nodes that are driven by control dependencies either since // Don't remove nodes that are driven by control dependencies either since
// we can't ensure (yet) that we won't increase the number of control // we can't ensure (yet) that we won't increase the number of control
@ -59,9 +60,12 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
// creation of 100 edges). // creation of 100 edges).
// Don't modify nodes that are connected to functions since that can result // Don't modify nodes that are connected to functions since that can result
// in inlining failures later on. // in inlining failures later on.
// Don't prune nodes that are driven by another device since these could be
// used to reduce cross device communication.
if (!rewriter.DrivesControlDependency(node) && if (!rewriter.DrivesControlDependency(node) &&
!rewriter.IsDrivenByControlDependency(node) && !rewriter.IsDrivenByControlDependency(node) &&
!rewriter.IsConnectedToFunction(node)) { !rewriter.IsConnectedToFunction(node) &&
!rewriter.IsDrivenByAnotherDevice(node)) {
nodes_to_delete.insert(&node); nodes_to_delete.insert(&node);
} }
} }

View File

@ -94,8 +94,8 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a}); Output b = ops::AddN(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c").WithDevice("CPU:0"), b); Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d").WithDevice("GPU:0"), c); Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::AddN(s.WithOpName("e"), {d}); Output e = ops::AddN(s.WithOpName("e"), {d});
GrapplerItem item; GrapplerItem item;
@ -119,9 +119,11 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
EXPECT_EQ(NodeName(e.name()), new_e.name()); EXPECT_EQ(NodeName(e.name()), new_e.name());
EXPECT_EQ(1, new_e.input_size()); EXPECT_EQ(1, new_e.input_size());
EXPECT_EQ(NodeName(c.name()), new_e.input(0)); EXPECT_EQ(NodeName(b.name()), new_e.input(0));
EXPECT_EQ(1, new_d.input_size()); EXPECT_EQ(1, new_d.input_size());
EXPECT_EQ(NodeName(c.name()), new_d.input(0)); EXPECT_EQ(NodeName(b.name()), new_d.input(0));
EXPECT_EQ(1, new_c.input_size());
EXPECT_EQ(NodeName(b.name()), new_c.input(0));
} }
TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
@ -235,6 +237,38 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) {
EXPECT_EQ(NodeName(c.name()), new_c.name()); EXPECT_EQ(NodeName(c.name()), new_c.name());
} }
TEST_F(ModelPrunerTest, PruningPerservesCrossDeviceIdentity) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c = ops::Const(s.WithOpName("c").WithDevice("/cpu:0"), 0.0f, {10, 10});
// Node i1 should be preserved.
Output i1 = ops::Identity(s.WithOpName("i1").WithDevice("/gpu:0"), c);
Output a1 = ops::AddN(s.WithOpName("a1").WithDevice("/gpu:0"), {i1});
Output a2 = ops::AddN(s.WithOpName("a2").WithDevice("/gpu:0"), {i1});
// Node i2 should be pruned since it resides on the sender's device.
Output i2 = ops::Identity(s.WithOpName("i2").WithDevice("/cpu:0"), c);
Output a3 = ops::AddN(s.WithOpName("a3").WithDevice("/gpu:0"), {i2});
Output a4 = ops::AddN(s.WithOpName("a4").WithDevice("/gpu:0"), {i2});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"a1", "a2", "a3", "a4"};
ModelPruner pruner;
GraphDef output;
Status status = pruner.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
for (const auto& node : output.node()) {
if (node.name() == "a1" || node.name() == "a2") {
EXPECT_EQ("i1", node.input(0));
} else if (node.name() == "a3" || node.name() == "a4") {
EXPECT_EQ("c", node.input(0));
}
}
}
} // namespace } // namespace
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow