Preserve identity nodes on the receiver side of a cross-device link instead of
preserving them on the sender side. PiperOrigin-RevId: 162698209
This commit is contained in:
parent
386f4aef0d
commit
a9e5591dbf
@ -69,6 +69,10 @@ bool GraphRewriter::IsConnectedToFunction(const NodeDef& node) const {
|
|||||||
return function_neighbors_.find(&node) != function_neighbors_.end();
|
return function_neighbors_.find(&node) != function_neighbors_.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const {
|
||||||
|
return cross_device_receivers_.find(&node) != cross_device_receivers_.end();
|
||||||
|
}
|
||||||
|
|
||||||
void GraphRewriter::RecordConnectivity(
|
void GraphRewriter::RecordConnectivity(
|
||||||
const NodeDef& node, const std::unordered_set<string>& function_names) {
|
const NodeDef& node, const std::unordered_set<string>& function_names) {
|
||||||
const bool is_function =
|
const bool is_function =
|
||||||
@ -94,6 +98,9 @@ void GraphRewriter::RecordConnectivity(
|
|||||||
function_neighbors_.insert(fanin);
|
function_neighbors_.insert(fanin);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (fanin->device() != node.device()) {
|
||||||
|
cross_device_receivers_.insert(&node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,9 +126,7 @@ void GraphRewriter::ForwardInputsInternal(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const NodeDef* input_node = itr->second;
|
const NodeDef* input_node = itr->second;
|
||||||
if ((input_node->device().empty() || node.device().empty() ||
|
if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
|
||||||
input_node->device() == node.device()) &&
|
|
||||||
nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
|
|
||||||
ForwardInputsInternal(*input_node, nodes_to_delete, new_node);
|
ForwardInputsInternal(*input_node, nodes_to_delete, new_node);
|
||||||
} else {
|
} else {
|
||||||
*new_node->add_input() = input;
|
*new_node->add_input() = input;
|
||||||
|
@ -51,6 +51,10 @@ class GraphRewriter {
|
|||||||
// fanout (excluding control dependencies) of 'node' is a function.
|
// fanout (excluding control dependencies) of 'node' is a function.
|
||||||
bool IsConnectedToFunction(const NodeDef& node) const;
|
bool IsConnectedToFunction(const NodeDef& node) const;
|
||||||
|
|
||||||
|
// Returns true if the node is driven by at least one node placed on another
|
||||||
|
// device.
|
||||||
|
bool IsDrivenByAnotherDevice(const NodeDef& node) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void RecordConnectivity(const NodeDef& node,
|
void RecordConnectivity(const NodeDef& node,
|
||||||
const std::unordered_set<string>& function_names);
|
const std::unordered_set<string>& function_names);
|
||||||
@ -63,6 +67,7 @@ class GraphRewriter {
|
|||||||
std::unordered_map<string, const NodeDef*> optimized_nodes_;
|
std::unordered_map<string, const NodeDef*> optimized_nodes_;
|
||||||
std::unordered_set<const NodeDef*> control_dependency_drivers_;
|
std::unordered_set<const NodeDef*> control_dependency_drivers_;
|
||||||
std::unordered_set<const NodeDef*> function_neighbors_;
|
std::unordered_set<const NodeDef*> function_neighbors_;
|
||||||
|
std::unordered_set<const NodeDef*> cross_device_receivers_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace grappler
|
} // end namespace grappler
|
||||||
|
@ -51,6 +51,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
|
if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't remove nodes that drive control dependencies.
|
// Don't remove nodes that drive control dependencies.
|
||||||
// Don't remove nodes that are driven by control dependencies either since
|
// Don't remove nodes that are driven by control dependencies either since
|
||||||
// we can't ensure (yet) that we won't increase the number of control
|
// we can't ensure (yet) that we won't increase the number of control
|
||||||
@ -59,9 +60,12 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
// creation of 100 edges).
|
// creation of 100 edges).
|
||||||
// Don't modify nodes that are connected to functions since that can result
|
// Don't modify nodes that are connected to functions since that can result
|
||||||
// in inlining failures later on.
|
// in inlining failures later on.
|
||||||
|
// Don't prune nodes that are driven by another device since these could be
|
||||||
|
// used to reduce cross device communication.
|
||||||
if (!rewriter.DrivesControlDependency(node) &&
|
if (!rewriter.DrivesControlDependency(node) &&
|
||||||
!rewriter.IsDrivenByControlDependency(node) &&
|
!rewriter.IsDrivenByControlDependency(node) &&
|
||||||
!rewriter.IsConnectedToFunction(node)) {
|
!rewriter.IsConnectedToFunction(node) &&
|
||||||
|
!rewriter.IsDrivenByAnotherDevice(node)) {
|
||||||
nodes_to_delete.insert(&node);
|
nodes_to_delete.insert(&node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -94,8 +94,8 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
|
|||||||
|
|
||||||
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
|
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
|
||||||
Output b = ops::AddN(s.WithOpName("b"), {a});
|
Output b = ops::AddN(s.WithOpName("b"), {a});
|
||||||
Output c = ops::Identity(s.WithOpName("c").WithDevice("CPU:0"), b);
|
Output c = ops::Identity(s.WithOpName("c"), b);
|
||||||
Output d = ops::Identity(s.WithOpName("d").WithDevice("GPU:0"), c);
|
Output d = ops::Identity(s.WithOpName("d"), c);
|
||||||
Output e = ops::AddN(s.WithOpName("e"), {d});
|
Output e = ops::AddN(s.WithOpName("e"), {d});
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
@ -119,9 +119,11 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
|
|||||||
EXPECT_EQ(NodeName(e.name()), new_e.name());
|
EXPECT_EQ(NodeName(e.name()), new_e.name());
|
||||||
|
|
||||||
EXPECT_EQ(1, new_e.input_size());
|
EXPECT_EQ(1, new_e.input_size());
|
||||||
EXPECT_EQ(NodeName(c.name()), new_e.input(0));
|
EXPECT_EQ(NodeName(b.name()), new_e.input(0));
|
||||||
EXPECT_EQ(1, new_d.input_size());
|
EXPECT_EQ(1, new_d.input_size());
|
||||||
EXPECT_EQ(NodeName(c.name()), new_d.input(0));
|
EXPECT_EQ(NodeName(b.name()), new_d.input(0));
|
||||||
|
EXPECT_EQ(1, new_c.input_size());
|
||||||
|
EXPECT_EQ(NodeName(b.name()), new_c.input(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
|
TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
|
||||||
@ -235,6 +237,38 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) {
|
|||||||
EXPECT_EQ(NodeName(c.name()), new_c.name());
|
EXPECT_EQ(NodeName(c.name()), new_c.name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks how the pruner treats Identity nodes around a cross-device edge:
// consumers of an Identity placed on the receiving device (i1) must keep
// reading from it, while consumers of an Identity placed on the sender's
// device (i2) must be rewired to read the constant directly.
TEST_F(ModelPrunerTest, PruningPerservesCrossDeviceIdentity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c").WithDevice("/cpu:0"), 0.0f, {10, 10});

  // Node i1 should be preserved.
  Output i1 = ops::Identity(s.WithOpName("i1").WithDevice("/gpu:0"), c);
  Output a1 = ops::AddN(s.WithOpName("a1").WithDevice("/gpu:0"), {i1});
  Output a2 = ops::AddN(s.WithOpName("a2").WithDevice("/gpu:0"), {i1});

  // Node i2 should be pruned since it resides on the sender's device.
  Output i2 = ops::Identity(s.WithOpName("i2").WithDevice("/cpu:0"), c);
  Output a3 = ops::AddN(s.WithOpName("a3").WithDevice("/gpu:0"), {i2});
  Output a4 = ops::AddN(s.WithOpName("a4").WithDevice("/gpu:0"), {i2});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"a1", "a2", "a3", "a4"};

  ModelPruner pruner;
  GraphDef output;
  Status status = pruner.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Count the consumers that are actually inspected so the test cannot pass
  // vacuously when a fetch node is missing from the optimized graph.
  int checked_consumers = 0;
  for (const auto& node : output.node()) {
    const bool reads_i1 = node.name() == "a1" || node.name() == "a2";
    const bool reads_c = node.name() == "a3" || node.name() == "a4";
    if (!reads_i1 && !reads_c) {
      continue;
    }
    ++checked_consumers;
    // Guard the input(0) access below; each AddN was built with one input.
    ASSERT_EQ(1, node.input_size());
    if (reads_i1) {
      EXPECT_EQ("i1", node.input(0));
    } else {
      EXPECT_EQ("c", node.input(0));
    }
  }
  EXPECT_EQ(4, checked_consumers);
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user