Invert static ordering of collective ops when enabled.
PiperOrigin-RevId: 229929957
This commit is contained in:
parent
8690cf7bf3
commit
7c65c591fb
@ -92,8 +92,8 @@ Status CreateControlDependencies(
|
|||||||
const auto& deps_j = (*data_dependencies)[collective_nodes[j]];
|
const auto& deps_j = (*data_dependencies)[collective_nodes[j]];
|
||||||
if (deps_i.find(instance_keys[j]) == deps_i.end() &&
|
if (deps_i.find(instance_keys[j]) == deps_i.end() &&
|
||||||
deps_j.find(instance_keys[i]) == deps_j.end()) {
|
deps_j.find(instance_keys[i]) == deps_j.end()) {
|
||||||
int src_idx = instance_keys[i] < instance_keys[j] ? i : j;
|
int src_idx = instance_keys[i] > instance_keys[j] ? i : j;
|
||||||
int dst_idx = instance_keys[i] < instance_keys[j] ? j : i;
|
int dst_idx = instance_keys[i] > instance_keys[j] ? j : i;
|
||||||
Node* src_node = collective_nodes[src_idx];
|
Node* src_node = collective_nodes[src_idx];
|
||||||
Node* dst_node = collective_nodes[dst_idx];
|
Node* dst_node = collective_nodes[dst_idx];
|
||||||
VLOG(1) << "Adding control dependency from node " << src_node->name()
|
VLOG(1) << "Adding control dependency from node " << src_node->name()
|
||||||
|
@ -137,18 +137,18 @@ std::unique_ptr<Graph> InitGraph() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Tests that in the graph created by `InitGraph`, exactly 2 control edges are
|
// Tests that in the graph created by `InitGraph`, exactly 2 control edges are
|
||||||
// added after calling `OrderCollectives`: c2_0 -> c3_0 and c2_1 -> c3_1.
|
// added after calling `OrderCollectives`: c3_0 -> c2_0 and c3_1 -> c2_1.
|
||||||
TEST(CollectiveOrderTest, SimpleOrder) {
|
TEST(CollectiveOrderTest, SimpleOrder) {
|
||||||
std::unique_ptr<Graph> graph = InitGraph();
|
std::unique_ptr<Graph> graph = InitGraph();
|
||||||
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
|
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
|
||||||
VerifyGraph(*graph, {"c1_0", "c1_1", "c2_0", "c2_1", "c3_0", "c3_1"},
|
VerifyGraph(*graph, {"c1_0", "c1_1", "c2_0", "c2_1", "c3_0", "c3_1"},
|
||||||
{{"c2_0", "c3_0"}, {"c2_1", "c3_1"}});
|
{{"c3_0", "c2_0"}, {"c3_1", "c2_1"}});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CollectiveOrderTest, SimpleOrderAttr) {
|
TEST(CollectiveOrderTest, SimpleOrderAttr) {
|
||||||
std::unique_ptr<Graph> graph = InitGraph();
|
std::unique_ptr<Graph> graph = InitGraph();
|
||||||
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
|
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
|
||||||
VerifyAttrs(*graph, {{"c3_0", {2}}, {"c3_1", {2}}});
|
VerifyAttrs(*graph, {{"c2_0", {3}}, {"c2_1", {3}}});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the following graph:
|
// Initialize the following graph:
|
||||||
@ -185,12 +185,12 @@ std::unique_ptr<Graph> InitGraph2() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Tests that in the graph created by `InitGraph2`, we add the following control
|
// Tests that in the graph created by `InitGraph2`, we add the following control
|
||||||
// edges after calling `OrderCollectives`: c2 -> c3, c3 -> c4. c2->c4 is
|
// edges after calling `OrderCollectives`: c4 -> c3, c3 -> c2. c4->c2 is
|
||||||
// pruned because it follows from the other two edges.
|
// pruned because it follows from the other two edges.
|
||||||
TEST(CollectiveOrderTest, SimpleOrder2) {
|
TEST(CollectiveOrderTest, SimpleOrder2) {
|
||||||
std::unique_ptr<Graph> graph = InitGraph2();
|
std::unique_ptr<Graph> graph = InitGraph2();
|
||||||
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
|
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
|
||||||
VerifyGraph(*graph, {"c1", "c2", "c3", "c4"}, {{"c2", "c3"}, {"c3", "c4"}});
|
VerifyGraph(*graph, {"c1", "c2", "c3", "c4"}, {{"c4", "c3"}, {"c3", "c2"}});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the following graph:
|
// Initialize the following graph:
|
||||||
@ -223,12 +223,12 @@ std::unique_ptr<Graph> InitGraphForPruning() {
|
|||||||
return graph;
|
return graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that in the graph created by `InitGraphForPruning`, we only add c1 ->
|
// Tests that in the graph created by `InitGraphForPruning`, we only add c4 ->
|
||||||
// c2, c2 -> c3, c3 -> c4, and other edges are pruned away.
|
// c3, c3 -> c2, c2 -> c1, and other edges are pruned away.
|
||||||
TEST(CollectiveOrderTest, Pruning) {
|
TEST(CollectiveOrderTest, Pruning) {
|
||||||
std::unique_ptr<Graph> graph = InitGraphForPruning();
|
std::unique_ptr<Graph> graph = InitGraphForPruning();
|
||||||
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
|
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
|
||||||
VerifyAttrs(*graph, {{"c4", {3}}, {"c3", {2}}, {"c2", {1}}});
|
VerifyAttrs(*graph, {{"c3", {4}}, {"c2", {3}}, {"c1", {2}}});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user