Verify replica count from AllReduce replica group config

PiperOrigin-RevId: 308628769
Change-Id: I636c90d9d153c8f5bf21a909ab65a47659465919
This commit is contained in:
HyoukJoong Lee 2020-04-27 08:50:29 -07:00 committed by TensorFlower Gardener
parent 491b81b78d
commit 127aa2a6c0
10 changed files with 183 additions and 85 deletions

View File

@ -238,7 +238,7 @@ TEST_F(AllReduceCombinerTest, NoDependentCombination) {
// Tests that AllReduce ops with different groups are not combined.
TEST_F(AllReduceCombinerTest, GroupAllReduce) {
auto module = CreateNewVerifiedModule();
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/4);
HloComputation::Builder b(TestName());
HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());

View File

@ -78,8 +78,8 @@ test {
ROOT tuple = (f32[8,16], f32[8,16], f32[8,16], f32[]) tuple(all-reduce, all-reduce.1, all-reduce.2, all-reduce.3)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
kModuleStr, /*replica_count=*/8));
AllReduceSimplifier simplifier(/*replica_count=*/8);
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(
@ -114,8 +114,8 @@ test {
ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), replica_groups={}, to_apply=sum
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
kModuleStr, /*replica_count=*/8));
AllReduceSimplifier simplifier(/*replica_count=*/8);
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
@ -155,8 +155,8 @@ test {
ROOT tuple = (f32[8,16], f32[8,16], f32[8,16]) tuple(all-reduce, all-reduce.1, all-reduce.2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
kModuleStr, /*replica_count=*/8));
AllReduceSimplifier simplifier(/*replica_count=*/8);
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(

View File

@ -447,8 +447,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -497,8 +498,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -565,8 +567,9 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -633,8 +636,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -675,8 +679,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -756,8 +761,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -809,8 +815,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -891,8 +898,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
@ -929,8 +937,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
@ -987,8 +996,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1062,8 +1072,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1110,8 +1121,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1180,8 +1192,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1224,8 +1237,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1312,8 +1326,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1363,8 +1378,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1452,8 +1468,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1502,8 +1519,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
@ -1579,8 +1597,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
@ -1616,8 +1635,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
@ -1691,8 +1711,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
@ -1719,8 +1740,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
@ -1739,14 +1761,17 @@ HloModule foobar
ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
%p = f32[2,4] parameter(0), sharding={replicated}
ROOT %all-reduce = f32[2,4] all-reduce(%p), replica_groups={{0,1}},
to_apply=%sum.f32
ROOT %all-reduce = f32[2,4] all-reduce(%p), to_apply=%sum.f32,
replica_groups={{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/64,
// Replacing replicated all-reduce is only triggered when there are enough
// replicas (currently > num_partitions * 8).
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/32));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/32,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
@ -1758,7 +1783,7 @@ ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
auto ar = root->operand(0);
auto divisor = root->operand(1)->operand(0);
EXPECT_TRUE(ar->channel_id());
EXPECT_TRUE(divisor->literal().IsAllFloat(4));
EXPECT_TRUE(divisor->literal().IsAllFloat(2));
}
TEST_F(ArCrsCombinerTest, AllReduceWithGlobalIdReplicaGroups) {
@ -1782,8 +1807,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();

View File

@ -275,7 +275,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
}
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
auto module = CreateNewVerifiedModule();
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2);
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
@ -304,7 +304,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
}
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
auto module = CreateNewVerifiedModule();
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2);
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});

View File

@ -42,6 +42,7 @@ using absl::string_view;
// One parameterized parser test case: an HLO module given as text plus the
// settings used to build the module it is parsed into.
struct TestData {
// Name of the test case; also passed as the module name when a verified
// module is constructed.
string test_name;
// HLO text of the module to parse.
string module_string;
// Replica count set on the HloModuleConfig used for parsing. Defaults to 1;
// cases whose replica groups name more replicas override it.
int64 replica_count = 1;
// When true the text is parsed into a VerifiedHloModule; otherwise it is
// parsed without verification.
bool enable_verification = true;
};
@ -1439,7 +1440,8 @@ ENTRY AllReduceWithSubgroups {
ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
}
)"
)",
/*replica_count=*/4,
},
// all-reduce with constrained layout
{
@ -1501,7 +1503,8 @@ ENTRY AllToAllWithSubgroups {
ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
}
)"
)",
/*replica_count=*/4,
},
// collective-permute
{
@ -1513,7 +1516,8 @@ ENTRY CollectivePermute {
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
}
)"
)",
/*replica_count=*/4
},
// replica-id
{
@ -1686,16 +1690,19 @@ class HloParameterizedParserTest
void ExpectEqual() {
std::unique_ptr<HloModule> module;
const string& original = GetParam().module_string;
HloModuleConfig config;
config.set_replica_count(GetParam().replica_count);
if (GetParam().enable_verification) {
auto verified_module = absl::make_unique<VerifiedHloModule>(
GetParam().test_name, HloModuleConfig(),
GetParam().test_name, config,
/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
module = std::move(verified_module);
} else {
TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
TF_ASSERT_OK_AND_ASSIGN(module,
ParseAndReturnUnverifiedModule(original, config));
}
if (proto_round_trip) {
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(

View File

@ -69,8 +69,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
module_str, /*replica_count=*/4));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{false, true});
@ -149,8 +149,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
module_str, /*replica_count=*/4));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true));
@ -575,8 +575,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
module_str, /*replica_count=*/2));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));

View File

@ -210,6 +210,29 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
hlo->ToString());
}
}
// When channel_id() or use_global_device_ids() is set, the device ids in the
// ReplicaGroup config no longer refer only to replica ids, so the check
// against the replica count is skipped.
if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
if (channel_instr->channel_id()) {
return Status::OK();
}
}
if (auto all_reduce = DynCast<HloAllReduceInstruction>(hlo)) {
if (all_reduce->use_global_device_ids()) {
return Status::OK();
}
}
int64 replica_count = hlo->GetModule()->config().replica_count();
if (!replicas_seen.empty() && replicas_seen.size() != replica_count) {
return InternalError(
"Replica count in HloModuleConfig is %d, but ReplicaGroup config "
"contains %d replicas: %s",
replica_count, replicas_seen.size(), hlo->ToString());
}
return Status::OK();
}

View File

@ -859,8 +859,17 @@ string ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups) {
return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
}
// Returns the total number of replica ids listed across all `replica_groups`,
// i.e. the sum of the group sizes. Returns 0 when there are no groups.
//
// Ids are counted, not de-duplicated: an id that appears in two groups is
// counted twice.
int64 ReplicaCount(const std::vector<std::vector<int64>>& replica_groups) {
  int64 replica_count = 0;
  // Iterate by const reference: only the size is needed, so copying each
  // inner vector would be wasted work.
  for (const auto& group : replica_groups) {
    replica_count += static_cast<int64>(group.size());
  }
  return replica_count;
}
StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
std::vector<std::vector<int64>> replica_groups) {
std::vector<std::vector<int64>> replica_groups,
absl::optional<int64> replica_count = absl::nullopt) {
const char* kTemplate = R"(
HloModule test
add {
@ -872,8 +881,17 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
p = f32[128]{0} parameter(0)
crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
})";
return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
HloModuleConfig config;
if (replica_count) {
config.set_replica_count(*replica_count);
} else {
config.set_replica_count(ReplicaCount(replica_groups));
}
return ParseAndReturnUnverifiedModule(
absl::StrReplaceAll(
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
config);
}
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
@ -907,6 +925,21 @@ TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
HasSubstr("Replica 4 is not named"));
}
// The replica groups name only 2 replicas while the module config is built
// with a replica count of 8; the verifier is expected to report the mismatch.
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      MakeAllReduceComputation(/*replica_groups=*/{{0, 1}},
                               /*replica_count=*/8));
  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica count in HloModuleConfig is 8, but "
                        "ReplicaGroup config contains 2 replicas"));
}
// The replica groups name 4 replicas while the module config is built with a
// replica count of 2; the verifier is expected to report the mismatch.
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      MakeAllReduceComputation(/*replica_groups=*/{{0, 1}, {2, 3}},
                               /*replica_count=*/2));
  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica count in HloModuleConfig is 2, but "
                        "ReplicaGroup config contains 4 replicas"));
}
StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
std::vector<std::vector<int64>> replica_groups) {
const char* kTemplate = R"(
@ -921,8 +954,12 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
p1 = f32[128]{0} parameter(1)
a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
})";
return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
HloModuleConfig config;
config.set_replica_count(ReplicaCount(replica_groups));
return ParseAndReturnUnverifiedModule(
absl::StrReplaceAll(
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
config);
}
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
@ -966,8 +1003,10 @@ TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
source_target_pairs={{0,1}, {0,2}, {1,0}}
}
)";
HloModuleConfig config;
config.set_replica_count(3);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnUnverifiedModule(kModuleStr, config));
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
HasSubstr("Source 0 appears more than once"));
}

View File

@ -117,16 +117,18 @@ std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
}
std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
const string& name) {
const string& name, int64 replica_count) {
return absl::make_unique<VerifiedHloModule>(
name, GetModuleConfigForTest(), verifier_layout_sensitive_,
name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_,
backend().compiler()->ShapeSizeBytesFunction());
}
StatusOr<std::unique_ptr<VerifiedHloModule>>
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
return ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
int64 replica_count) {
return ParseAndReturnVerifiedModule(hlo_text,
GetModuleConfigForTest(replica_count));
}
StatusOr<std::unique_ptr<VerifiedHloModule>>

View File

@ -84,11 +84,11 @@ class HloTestBase : public ::testing::Test {
// Like CreateNewUnverifiedModule, except the HloModule returned here runs the
// HLO verifier on destruction.
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
const string& name = TestName());
const string& name = TestName(), int64 replica_count = 1);
// Parses the given string and returns module as a VerifiedHloModule.
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text);
absl::string_view hlo_text, int64 replica_count = 1);
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config);
@ -130,9 +130,10 @@ class HloTestBase : public ::testing::Test {
virtual DebugOptions GetDebugOptionsForTest();
// Gets an HloModuleConfig with options appropriate for tests.
HloModuleConfig GetModuleConfigForTest() {
HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
config.set_replica_count(replica_count);
return config;
}