[XLA:GPU] [NFC] CollectiveScheduleLinearizer: test for dependent collectives

PiperOrigin-RevId: 361230404
Change-Id: Id20c99b3904871dc73b01e0d8a7f7a9237e76243
This commit is contained in:
George Karpenkov 2021-03-05 14:53:37 -08:00 committed by TensorFlower Gardener
parent 88de667855
commit 0341acc930
2 changed files with 26 additions and 1 deletion

View File

@@ -52,7 +52,7 @@ StatusOr<bool> CollectivesScheduleLinearizer::Run(HloModule* module) {
for (HloInstruction* instruction : computation->instructions()) {
if (auto* next = DynCast<HloCollectiveInstruction>(instruction)) {
if (prev != nullptr && !reachability->IsConnected(next, prev)) {
// We check for reachability as we don't want to form a cycle.
// If prev and next are independent, enforce ordering.
TF_RETURN_IF_ERROR(prev->AddControlDependencyTo(next));
VLOG(1) << "Adding control dependency from " << prev->ToString()
<< " to " << next->ToString();

View File

@@ -113,5 +113,30 @@ ENTRY entry {
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 1);
}
// Data-dependent collectives: `c2` takes `c1` as its operand, so the two
// all-reduces are already ordered by a data edge. After
// InsertCollectivesSchedule (presumably a fixture helper that runs the
// linearizer pass -- defined outside this view), the test expects the entry
// computation to carry no control edges at all.
TEST_F(CollectivesScheduleLinearizerTest, DependentCollectives) {
  const absl::string_view hlo_text = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT out = f32[] add(a, b)
}
ENTRY entry {
p0 = f32[100] parameter(0), parameter_replication={false}
p1 = f32[100] parameter(1), parameter_replication={false}
c1 = f32[100] all-reduce(p0), replica_groups={}, to_apply=sum
c2 = f32[100] all-reduce(c1), replica_groups={}, to_apply=sum
ROOT out = f32[100] add(c1, c2)
}
)";

  // Parse (and verify) the module; the macro aborts the test on failure.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloModule* const raw_module = parsed_module.get();
  InsertCollectivesSchedule(raw_module);

  // Count control edges in the entry computation only.
  auto& entry = *raw_module->entry_computation();
  EXPECT_EQ(CountControlEdges(entry), 0);
}
} // namespace
} // namespace xla