From d7992c051ee067d3bf759b60184c3728f4b75f26 Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Mon, 8 Feb 2021 11:56:35 -0800
Subject: [PATCH] [XLA] Disallow empty replica groups when
 use_global_device_ids() == true

- Extend this check from AllGather to all other instructions that have
  use_global_device_ids()

PiperOrigin-RevId: 356318717
Change-Id: Ib5397454eed6f212e71e0aab45f52ad83e870715
---
 .../compiler/xla/service/hlo_verifier.cc      | 37 ++++++++++---------
 .../compiler/xla/service/hlo_verifier.h       |  2 +-
 .../compiler/xla/service/hlo_verifier_test.cc | 24 ++++++++++++
 3 files changed, 45 insertions(+), 18 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 2c8736ad42c..40ecd80e4c9 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -191,7 +191,8 @@ Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
 // This is just a minimal set of checks; some instructions may have additional
 // requirements. For example, all-to-all requires that all ReplicaGroups have
 // the same number of replicas, but that isn't checked here.
-static Status CheckReplicaGroups(HloInstruction* hlo) {
+static Status CheckReplicaGroups(HloInstruction* hlo,
+                                 bool use_global_device_ids) {
   std::set<int64> replicas_seen;
   for (const ReplicaGroup& g : hlo->replica_groups()) {
     if (g.replica_ids().empty()) {
@@ -214,19 +215,24 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
     }
   }
 
+  // If use_global_device_ids() is set, replica_groups cannot be empty.
   // When the channel_id() or use_global_device_ids() is set, device ids in
   // ReplicaGroup config no longer only mean replica ids. So we skip the check
   // on the replica count.
+  if (use_global_device_ids) {
+    if (hlo->replica_groups().empty()) {
+      return InternalError(
+          "Replica group must be specified when use_global_device_ids is true");
+    }
+    // No need to check replica_count.
+    return Status::OK();
+  }
+
   if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
     if (channel_instr->channel_id()) {
       return Status::OK();
     }
   }
-  if (auto all_reduce = DynCast<HloAllReduceInstruction>(hlo)) {
-    if (all_reduce->use_global_device_ids()) {
-      return Status::OK();
-    }
-  }
 
   int64 replica_count = hlo->GetModule()->config().replica_count();
   if (replica_count != 1 && !replicas_seen.empty() &&
@@ -242,14 +248,10 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
 
 Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
   auto ag = Cast<HloAllGatherInstruction>(hlo);
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag));
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, ag->use_global_device_ids()));
   TF_RET_CHECK(ag->all_gather_dimension() >= 0);
   TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
   TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
-  if (ag->use_global_device_ids() && ag->replica_groups().empty()) {
-    return InternalError(
-        "Replica group must be specified when use_global_device_ids is true");
-  }
 
   int64 shard_count = CeilOfRatio(
       ag->shape().dimensions(ag->all_gather_dimension()),
@@ -274,20 +276,21 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
                             shard_count));
 }
 
-Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(crs));
+Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
+  auto ar = Cast<HloAllReduceInstruction>(hlo);
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(ar, ar->use_global_device_ids()));
 
   std::vector<const Shape*> operand_shapes;
-  for (const HloInstruction* operand : crs->operands()) {
+  for (const HloInstruction* operand : hlo->operands()) {
     operand_shapes.push_back(&operand->shape());
   }
-  return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes));
+  return CheckShape(hlo, ShapeInference::InferAllReduceShape(operand_shapes));
 }
 
 Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo));
-
   auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, /*use_global_device_ids=*/false));
+
   TF_RET_CHECK(all_to_all != nullptr);
   if (all_to_all->split_dimension()) {
     if (hlo->replica_groups().empty()) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 03fca5938ff..980fdd170b4 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -57,7 +57,7 @@ class ShapeVerifier : public DfsHloVisitor {
   Status HandleCholesky(HloInstruction* hlo) override;
   Status HandleTriangularSolve(HloInstruction* hlo) override;
   Status HandleAllGather(HloInstruction* hlo) override;
-  Status HandleAllReduce(HloInstruction* crs) override;
+  Status HandleAllReduce(HloInstruction* hlo) override;
   Status HandleAllToAll(HloInstruction* hlo) override;
   Status HandleCollectivePermute(HloInstruction* hlo) override;
   Status HandleCollectivePermuteStart(HloInstruction* hlo) override;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index c6c09e3dee1..7a95cdba7c3 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -1292,5 +1292,29 @@ TEST_F(HloVerifierTest, ComparisonTypePred) {
               HasSubstr("Expected comparison type UNSIGNED"));
 }
 
+TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
+  const char* const hlo_string = R"(
+  HloModule Module
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY CRS {
+    input = f32[8]{0} parameter(0)
+    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
+                         use_global_device_ids=true, to_apply=add
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Replica group must be specified when "
+                        "use_global_device_ids is true"));
+}
+
 }  // namespace
 }  // namespace xla
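
Addendum (illustrative only, not applied by the patch): the HLO below sketches
what the new verifier check rejects versus what it expects. The rejected form
is the test case above; the accepted form is an assumed example using the
common flattened-ID pattern, where use_global_device_ids is paired with a
channel_id and explicit groups of global device ids (the groups {{0,1},{2,3}}
and channel_id=1 are made up for illustration).

  // Rejected after this change: use_global_device_ids=true with empty
  // replica_groups.
  ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
      use_global_device_ids=true, to_apply=add

  // Accepted: every participating global device id is listed explicitly
  // (assumed group assignment, shown for contrast).
  ROOT crs = f32[8]{0} all-reduce(input), channel_id=1,
      replica_groups={{0,1},{2,3}}, use_global_device_ids=true, to_apply=add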