[XLA] Disallow empty replica groups when use_global_device_ids() == true
- Extend this check from AllGather to all other instructions that have use_global_device_ids

PiperOrigin-RevId: 356318717
Change-Id: Ib5397454eed6f212e71e0aab45f52ad83e870715
parent 879fd3bc67
commit d7992c051e
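
With this change, the verifier rejects a collective whose replica_groups list is empty while use_global_device_ids=true. A minimal HLO sketch of both cases; the rejected form mirrors the new test below, while the accepted form is an illustration that assumes the usual convention of pairing use_global_device_ids with a channel_id and explicit groups:

Rejected after this change (empty replica_groups with use_global_device_ids):

    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
        use_global_device_ids=true, to_apply=add

Accepted form (sketch, assumed valid usage rather than taken from this commit):

    ROOT crs = f32[8]{0} all-reduce(input), channel_id=1,
        replica_groups={{0,1},{2,3}}, use_global_device_ids=true, to_apply=add
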
@@ -191,7 +191,8 @@ Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
 // This is just a minimal set of checks; some instructions may have additional
 // requirements. For example, all-to-all requires that all ReplicaGroups have
 // the same number of replicas, but that isn't checked here.
-static Status CheckReplicaGroups(HloInstruction* hlo) {
+static Status CheckReplicaGroups(HloInstruction* hlo,
+                                 bool use_global_device_ids) {
   std::set<int64> replicas_seen;
   for (const ReplicaGroup& g : hlo->replica_groups()) {
     if (g.replica_ids().empty()) {
@@ -214,19 +215,24 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
     }
   }
 
+  // If use_global_device_ids() is set, replica_groups cannot be empty.
   // When the channel_id() or use_global_device_ids() is set, device ids in
   // ReplicaGroup config no longer only mean replica ids. So we skip the check
   // on the replica count.
+  if (use_global_device_ids) {
+    if (hlo->replica_groups().empty()) {
+      return InternalError(
+          "Replica group must be specified when use_global_device_ids is true");
+    }
+    // No need to check replica_count.
+    return Status::OK();
+  }
+
   if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
     if (channel_instr->channel_id()) {
       return Status::OK();
     }
   }
-  if (auto all_reduce = DynCast<HloAllReduceInstruction>(hlo)) {
-    if (all_reduce->use_global_device_ids()) {
-      return Status::OK();
-    }
-  }
 
   int64 replica_count = hlo->GetModule()->config().replica_count();
   if (replica_count != 1 && !replicas_seen.empty() &&
@@ -242,14 +248,10 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
 
 Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
   auto ag = Cast<HloAllGatherInstruction>(hlo);
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag));
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, ag->use_global_device_ids()));
   TF_RET_CHECK(ag->all_gather_dimension() >= 0);
   TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
   TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
-  if (ag->use_global_device_ids() && ag->replica_groups().empty()) {
-    return InternalError(
-        "Replica group must be specified when use_global_device_ids is true");
-  }
 
   int64 shard_count = CeilOfRatio(
       ag->shape().dimensions(ag->all_gather_dimension()),
@@ -274,20 +276,21 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
         shard_count));
   }
 
-Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(crs));
+Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
+  auto ar = Cast<HloAllReduceInstruction>(hlo);
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(ar, ar->use_global_device_ids()));
 
   std::vector<const Shape*> operand_shapes;
-  for (const HloInstruction* operand : crs->operands()) {
+  for (const HloInstruction* operand : hlo->operands()) {
     operand_shapes.push_back(&operand->shape());
   }
-  return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes));
+  return CheckShape(hlo, ShapeInference::InferAllReduceShape(operand_shapes));
 }
 
 Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo));
-
   auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, /*use_global_device_ids=*/false));
+
   TF_RET_CHECK(all_to_all != nullptr);
   if (all_to_all->split_dimension()) {
     if (hlo->replica_groups().empty()) {
@@ -57,7 +57,7 @@ class ShapeVerifier : public DfsHloVisitor {
   Status HandleCholesky(HloInstruction* hlo) override;
   Status HandleTriangularSolve(HloInstruction* hlo) override;
   Status HandleAllGather(HloInstruction* hlo) override;
-  Status HandleAllReduce(HloInstruction* crs) override;
+  Status HandleAllReduce(HloInstruction* hlo) override;
   Status HandleAllToAll(HloInstruction* hlo) override;
   Status HandleCollectivePermute(HloInstruction* hlo) override;
   Status HandleCollectivePermuteStart(HloInstruction* hlo) override;
@@ -1292,5 +1292,29 @@ TEST_F(HloVerifierTest, ComparisonTypePred) {
               HasSubstr("Expected comparison type UNSIGNED"));
 }
 
+TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
+  const char* const hlo_string = R"(
+  HloModule Module
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY CRS {
+    input = f32[8]{0} parameter(0)
+    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
+        use_global_device_ids=true, to_apply=add
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Replica group must be specified when "
+                        "use_global_device_ids is true"));
+}
+
 } // namespace
 } // namespace xla
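
For reference, a standalone driver in the spirit of the test above could surface the new error outside the test fixture. This is a sketch only; the HloVerifier constructor arguments and include paths are assumptions, not taken from this commit.

#include <memory>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"

namespace xla {

// Parses HLO text and runs the verifier over it. For a module with
// replica_groups={} and use_global_device_ids=true this now returns the
// "Replica group must be specified when use_global_device_ids is true" error.
Status VerifyModuleText(absl::string_view hlo_text) {
  auto module_or = ParseAndReturnUnverifiedModule(hlo_text);
  if (!module_or.ok()) {
    return module_or.status();
  }
  // Constructor arguments are assumed for illustration.
  HloVerifier verifier(/*layout_sensitive=*/false,
                       /*allow_mixed_precision=*/false);
  return verifier.Run(module_or.ValueOrDie().get()).status();
}

}  // namespace xla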