[XLA] Disallow empty replica groups when use_global_device_ids() == true

- Extend this check from AllGather to all other instruction that have use_global_device_ids

PiperOrigin-RevId: 356318717
Change-Id: Ib5397454eed6f212e71e0aab45f52ad83e870715
This commit is contained in:
Rahul Joshi 2021-02-08 11:56:35 -08:00 committed by TensorFlower Gardener
parent 879fd3bc67
commit d7992c051e
3 changed files with 45 additions and 18 deletions

View File

@ -191,7 +191,8 @@ Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
// This is just a minimal set of checks; some instructions may have additional
// requirements. For example, all-to-all requires that all ReplicaGroups have
// the same number of replicas, but that isn't checked here.
static Status CheckReplicaGroups(HloInstruction* hlo) {
static Status CheckReplicaGroups(HloInstruction* hlo,
bool use_global_device_ids) {
std::set<int64> replicas_seen;
for (const ReplicaGroup& g : hlo->replica_groups()) {
if (g.replica_ids().empty()) {
@ -214,19 +215,24 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
}
}
// If use_global_device_ids() is set, replica_groups cannot be empty.
// When the channel_id() or use_global_device_ids() is set, device ids in
// ReplicaGroup config no longer only mean replica ids. So we skip the check
// on the replica count.
if (use_global_device_ids) {
if (hlo->replica_groups().empty()) {
return InternalError(
"Replica group must be specified when use_global_device_ids is true");
}
// No need to check replica_count.
return Status::OK();
}
if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
if (channel_instr->channel_id()) {
return Status::OK();
}
}
if (auto all_reduce = DynCast<HloAllReduceInstruction>(hlo)) {
if (all_reduce->use_global_device_ids()) {
return Status::OK();
}
}
int64 replica_count = hlo->GetModule()->config().replica_count();
if (replica_count != 1 && !replicas_seen.empty() &&
@ -242,14 +248,10 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
auto ag = Cast<HloAllGatherInstruction>(hlo);
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag));
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, ag->use_global_device_ids()));
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
if (ag->use_global_device_ids() && ag->replica_groups().empty()) {
return InternalError(
"Replica group must be specified when use_global_device_ids is true");
}
int64 shard_count = CeilOfRatio(
ag->shape().dimensions(ag->all_gather_dimension()),
@ -274,20 +276,21 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
shard_count));
}
Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
TF_RETURN_IF_ERROR(CheckReplicaGroups(crs));
Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
auto ar = Cast<HloAllReduceInstruction>(hlo);
TF_RETURN_IF_ERROR(CheckReplicaGroups(ar, ar->use_global_device_ids()));
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : crs->operands()) {
for (const HloInstruction* operand : hlo->operands()) {
operand_shapes.push_back(&operand->shape());
}
return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes));
return CheckShape(hlo, ShapeInference::InferAllReduceShape(operand_shapes));
}
Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo));
auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, /*use_global_device_ids=*/false));
TF_RET_CHECK(all_to_all != nullptr);
if (all_to_all->split_dimension()) {
if (hlo->replica_groups().empty()) {

View File

@ -57,7 +57,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleCholesky(HloInstruction* hlo) override;
Status HandleTriangularSolve(HloInstruction* hlo) override;
Status HandleAllGather(HloInstruction* hlo) override;
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleAllReduce(HloInstruction* hlo) override;
Status HandleAllToAll(HloInstruction* hlo) override;
Status HandleCollectivePermute(HloInstruction* hlo) override;
Status HandleCollectivePermuteStart(HloInstruction* hlo) override;

View File

@ -1292,5 +1292,29 @@ TEST_F(HloVerifierTest, ComparisonTypePred) {
HasSubstr("Expected comparison type UNSIGNED"));
}
TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
const char* const hlo_string = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
use_global_device_ids=true, to_apply=add
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Replica group must be specified when "
"use_global_device_ids is true"));
}
} // namespace
} // namespace xla