Invert static ordering of collective ops when enabled.

PiperOrigin-RevId: 229929957
This commit is contained in:
Ayush Dubey 2019-01-18 08:08:57 -08:00 committed by TensorFlower Gardener
parent 8690cf7bf3
commit 7c65c591fb
2 changed files with 10 additions and 10 deletions

View File

@ -92,8 +92,8 @@ Status CreateControlDependencies(
const auto& deps_j = (*data_dependencies)[collective_nodes[j]];
if (deps_i.find(instance_keys[j]) == deps_i.end() &&
deps_j.find(instance_keys[i]) == deps_j.end()) {
int src_idx = instance_keys[i] < instance_keys[j] ? i : j;
int dst_idx = instance_keys[i] < instance_keys[j] ? j : i;
int src_idx = instance_keys[i] > instance_keys[j] ? i : j;
int dst_idx = instance_keys[i] > instance_keys[j] ? j : i;
Node* src_node = collective_nodes[src_idx];
Node* dst_node = collective_nodes[dst_idx];
VLOG(1) << "Adding control dependency from node " << src_node->name()

View File

@ -137,18 +137,18 @@ std::unique_ptr<Graph> InitGraph() {
}
// Tests that in the graph created by `InitGraph`, exactly 2 control edges are
// added after calling `OrderCollectives`: c2_0 -> c3_0 and c2_1 -> c3_1.
// added after calling `OrderCollectives`: c3_0 -> c2_0 and c3_1 -> c2_1.
TEST(CollectiveOrderTest, SimpleOrder) {
std::unique_ptr<Graph> graph = InitGraph();
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
VerifyGraph(*graph, {"c1_0", "c1_1", "c2_0", "c2_1", "c3_0", "c3_1"},
{{"c2_0", "c3_0"}, {"c2_1", "c3_1"}});
{{"c3_0", "c2_0"}, {"c3_1", "c2_1"}});
}
TEST(CollectiveOrderTest, SimpleOrderAttr) {
std::unique_ptr<Graph> graph = InitGraph();
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
VerifyAttrs(*graph, {{"c3_0", {2}}, {"c3_1", {2}}});
VerifyAttrs(*graph, {{"c2_0", {3}}, {"c2_1", {3}}});
}
// Initialize the following graph:
@ -185,12 +185,12 @@ std::unique_ptr<Graph> InitGraph2() {
}
// Tests that in the graph created by `InitGraph2`, we add the following control
// edges after calling `OrderCollectives`: c2 -> c3, c3 -> c4. c2->c4 is
// edges after calling `OrderCollectives`: c4 -> c3, c3 -> c2. c4->c2 is
// pruned because it follows from the other two edges.
TEST(CollectiveOrderTest, SimpleOrder2) {
std::unique_ptr<Graph> graph = InitGraph2();
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
VerifyGraph(*graph, {"c1", "c2", "c3", "c4"}, {{"c2", "c3"}, {"c3", "c4"}});
VerifyGraph(*graph, {"c1", "c2", "c3", "c4"}, {{"c4", "c3"}, {"c3", "c2"}});
}
// Initialize the following graph:
@ -223,12 +223,12 @@ std::unique_ptr<Graph> InitGraphForPruning() {
return graph;
}
// Tests that in the graph created by `InitGraphForPruning`, we only add c1 ->
// c2, c2 -> c3, c3 -> c4, and other edges are pruned away.
// Tests that in the graph created by `InitGraphForPruning`, we only add c4 ->
// c3, c3 -> c2, c2 -> c1, and other edges are pruned away.
TEST(CollectiveOrderTest, Pruning) {
std::unique_ptr<Graph> graph = InitGraphForPruning();
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
VerifyAttrs(*graph, {{"c4", {3}}, {"c3", {2}}, {"c2", {1}}});
VerifyAttrs(*graph, {{"c3", {4}}, {"c2", {3}}, {"c1", {2}}});
}
} // namespace