[XLA:GPU] [NFC] CollectiveScheduleLinearizer: test for dependent collectives
PiperOrigin-RevId: 361230404 Change-Id: Id20c99b3904871dc73b01e0d8a7f7a9237e76243
This commit is contained in:
parent
88de667855
commit
0341acc930
@ -52,7 +52,7 @@ StatusOr<bool> CollectivesScheduleLinearizer::Run(HloModule* module) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (auto* next = DynCast<HloCollectiveInstruction>(instruction)) {
|
||||
if (prev != nullptr && !reachability->IsConnected(next, prev)) {
|
||||
// We check for reachability as we don't want to form a cycle.
|
||||
// If prev and next are independent, enforce ordering.
|
||||
TF_RETURN_IF_ERROR(prev->AddControlDependencyTo(next));
|
||||
VLOG(1) << "Adding control dependency from " << prev->ToString()
|
||||
<< " to " << next->ToString();
|
||||
|
@ -113,5 +113,30 @@ ENTRY entry {
|
||||
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 1);
|
||||
}
|
||||
|
||||
TEST_F(CollectivesScheduleLinearizerTest, DependentCollectives) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
sum {
|
||||
a = f32[] parameter(0)
|
||||
b = f32[] parameter(1)
|
||||
ROOT out = f32[] add(a, b)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[100] parameter(0), parameter_replication={false}
|
||||
p1 = f32[100] parameter(1), parameter_replication={false}
|
||||
c1 = f32[100] all-reduce(p0), replica_groups={}, to_apply=sum
|
||||
c2 = f32[100] all-reduce(c1), replica_groups={}, to_apply=sum
|
||||
ROOT out = f32[100] add(c1, c2)
|
||||
}
|
||||
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCollectivesSchedule(module.get());
|
||||
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user