[XLA:GPU] [NFC] CollectiveScheduleLinearizer: test for dependent collectives
PiperOrigin-RevId: 361230404 Change-Id: Id20c99b3904871dc73b01e0d8a7f7a9237e76243
This commit is contained in:
parent
88de667855
commit
0341acc930
@ -52,7 +52,7 @@ StatusOr<bool> CollectivesScheduleLinearizer::Run(HloModule* module) {
|
|||||||
for (HloInstruction* instruction : computation->instructions()) {
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
if (auto* next = DynCast<HloCollectiveInstruction>(instruction)) {
|
if (auto* next = DynCast<HloCollectiveInstruction>(instruction)) {
|
||||||
if (prev != nullptr && !reachability->IsConnected(next, prev)) {
|
if (prev != nullptr && !reachability->IsConnected(next, prev)) {
|
||||||
// We check for reachability as we don't want to form a cycle.
|
// If prev and next are independent, enforce ordering.
|
||||||
TF_RETURN_IF_ERROR(prev->AddControlDependencyTo(next));
|
TF_RETURN_IF_ERROR(prev->AddControlDependencyTo(next));
|
||||||
VLOG(1) << "Adding control dependency from " << prev->ToString()
|
VLOG(1) << "Adding control dependency from " << prev->ToString()
|
||||||
<< " to " << next->ToString();
|
<< " to " << next->ToString();
|
||||||
|
@ -113,5 +113,30 @@ ENTRY entry {
|
|||||||
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 1);
|
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CollectivesScheduleLinearizerTest, DependentCollectives) {
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
sum {
|
||||||
|
a = f32[] parameter(0)
|
||||||
|
b = f32[] parameter(1)
|
||||||
|
ROOT out = f32[] add(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
p0 = f32[100] parameter(0), parameter_replication={false}
|
||||||
|
p1 = f32[100] parameter(1), parameter_replication={false}
|
||||||
|
c1 = f32[100] all-reduce(p0), replica_groups={}, to_apply=sum
|
||||||
|
c2 = f32[100] all-reduce(c1), replica_groups={}, to_apply=sum
|
||||||
|
ROOT out = f32[100] add(c1, c2)
|
||||||
|
}
|
||||||
|
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
InsertCollectivesSchedule(module.get());
|
||||||
|
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 0);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user