Verify replica count from AllReduce replica group config
PiperOrigin-RevId: 308628769 Change-Id: I636c90d9d153c8f5bf21a909ab65a47659465919
This commit is contained in:
parent
491b81b78d
commit
127aa2a6c0
@ -238,7 +238,7 @@ TEST_F(AllReduceCombinerTest, NoDependentCombination) {
|
||||
|
||||
// Tests that AllReduce ops with different groups are not combined.
|
||||
TEST_F(AllReduceCombinerTest, GroupAllReduce) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/4);
|
||||
HloComputation::Builder b(TestName());
|
||||
HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());
|
||||
|
||||
|
@ -78,8 +78,8 @@ test {
|
||||
ROOT tuple = (f32[8,16], f32[8,16], f32[8,16], f32[]) tuple(all-reduce, all-reduce.1, all-reduce.2, all-reduce.3)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||
kModuleStr, /*replica_count=*/8));
|
||||
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
@ -114,8 +114,8 @@ test {
|
||||
ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), replica_groups={}, to_apply=sum
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||
kModuleStr, /*replica_count=*/8));
|
||||
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
@ -155,8 +155,8 @@ test {
|
||||
ROOT tuple = (f32[8,16], f32[8,16], f32[8,16]) tuple(all-reduce, all-reduce.1, all-reduce.2)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||
kModuleStr, /*replica_count=*/8));
|
||||
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
|
@ -447,8 +447,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -497,8 +498,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -565,8 +567,9 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -633,8 +636,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -675,8 +679,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -756,8 +761,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -809,8 +815,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -891,8 +898,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
||||
/*spmd_partition=*/false);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
@ -929,8 +937,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
||||
/*spmd_partition=*/true);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
@ -987,8 +996,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1062,8 +1072,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1110,8 +1121,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1180,8 +1192,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1224,8 +1237,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1312,8 +1326,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1363,8 +1378,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1452,8 +1468,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1502,8 +1519,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
auto crs_before =
|
||||
module->entry_computation()->root_instruction()->operands()[0];
|
||||
auto replica_groups_before = crs_before->replica_groups();
|
||||
@ -1579,8 +1597,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
|
||||
/*spmd_partition=*/false);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
@ -1616,8 +1635,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
|
||||
/*spmd_partition=*/true);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
@ -1691,8 +1711,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
||||
/*spmd_partition=*/false);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
@ -1719,8 +1740,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
||||
/*spmd_partition=*/true);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
@ -1739,14 +1761,17 @@ HloModule foobar
|
||||
|
||||
ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
|
||||
%p = f32[2,4] parameter(0), sharding={replicated}
|
||||
ROOT %all-reduce = f32[2,4] all-reduce(%p), replica_groups={{0,1}},
|
||||
to_apply=%sum.f32
|
||||
ROOT %all-reduce = f32[2,4] all-reduce(%p), to_apply=%sum.f32,
|
||||
replica_groups={{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}}
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/64,
|
||||
// Replacing replicated all-reduce is only triggered when there are enough
|
||||
// replicas (currently > num_partitions * 8).
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/32));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/32,
|
||||
/*spmd_partition=*/true);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
EXPECT_TRUE(changed);
|
||||
@ -1758,7 +1783,7 @@ ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
|
||||
auto ar = root->operand(0);
|
||||
auto divisor = root->operand(1)->operand(0);
|
||||
EXPECT_TRUE(ar->channel_id());
|
||||
EXPECT_TRUE(divisor->literal().IsAllFloat(4));
|
||||
EXPECT_TRUE(divisor->literal().IsAllFloat(2));
|
||||
}
|
||||
|
||||
TEST_F(ArCrsCombinerTest, AllReduceWithGlobalIdReplicaGroups) {
|
||||
@ -1782,8 +1807,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2,
|
||||
/*spmd_partition=*/true);
|
||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||
|
@ -275,7 +275,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
|
||||
}
|
||||
|
||||
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2);
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
|
||||
@ -304,7 +304,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
|
||||
}
|
||||
|
||||
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2);
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
|
||||
|
@ -42,6 +42,7 @@ using absl::string_view;
|
||||
struct TestData {
|
||||
string test_name;
|
||||
string module_string;
|
||||
int64 replica_count = 1;
|
||||
bool enable_verification = true;
|
||||
};
|
||||
|
||||
@ -1439,7 +1440,8 @@ ENTRY AllReduceWithSubgroups {
|
||||
ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
|
||||
}
|
||||
|
||||
)"
|
||||
)",
|
||||
/*replica_count=*/4,
|
||||
},
|
||||
// all-reduce with constrained layout
|
||||
{
|
||||
@ -1501,7 +1503,8 @@ ENTRY AllToAllWithSubgroups {
|
||||
ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
|
||||
}
|
||||
|
||||
)"
|
||||
)",
|
||||
/*replica_count=*/4,
|
||||
},
|
||||
// collective-permute
|
||||
{
|
||||
@ -1513,7 +1516,8 @@ ENTRY CollectivePermute {
|
||||
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
|
||||
}
|
||||
|
||||
)"
|
||||
)",
|
||||
/*replica_count=*/4
|
||||
},
|
||||
// replica-id
|
||||
{
|
||||
@ -1686,16 +1690,19 @@ class HloParameterizedParserTest
|
||||
void ExpectEqual() {
|
||||
std::unique_ptr<HloModule> module;
|
||||
const string& original = GetParam().module_string;
|
||||
HloModuleConfig config;
|
||||
config.set_replica_count(GetParam().replica_count);
|
||||
if (GetParam().enable_verification) {
|
||||
auto verified_module = absl::make_unique<VerifiedHloModule>(
|
||||
GetParam().test_name, HloModuleConfig(),
|
||||
GetParam().test_name, config,
|
||||
/*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||
ShapeUtil::ByteSizeOfElements);
|
||||
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
|
||||
module = std::move(verified_module);
|
||||
} else {
|
||||
TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
|
||||
TF_ASSERT_OK_AND_ASSIGN(module,
|
||||
ParseAndReturnUnverifiedModule(original, config));
|
||||
}
|
||||
if (proto_round_trip) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
|
||||
|
@ -69,8 +69,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||
module_str, /*replica_count=*/4));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{false, true});
|
||||
@ -149,8 +149,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||
module_str, /*replica_count=*/4));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloReplicationAnalysis> analysis,
|
||||
HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true));
|
||||
@ -575,8 +575,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||
module_str, /*replica_count=*/2));
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
|
||||
HloReplicationAnalysis::Run(
|
||||
module.get(), /*cross_partition_spmd=*/false));
|
||||
|
@ -210,6 +210,29 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
|
||||
hlo->ToString());
|
||||
}
|
||||
}
|
||||
|
||||
// When the channel_id() or use_global_device_ids() is set, device ids in
|
||||
// ReplicaGroup config no longer only mean replica ids. So we skip the check
|
||||
// on the replica count.
|
||||
if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
|
||||
if (channel_instr->channel_id()) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
if (auto all_reduce = DynCast<HloAllReduceInstruction>(hlo)) {
|
||||
if (all_reduce->use_global_device_ids()) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
int64 replica_count = hlo->GetModule()->config().replica_count();
|
||||
if (!replicas_seen.empty() && replicas_seen.size() != replica_count) {
|
||||
return InternalError(
|
||||
"Replica count in HloModuleConfig is %d, but ReplicaGroup config "
|
||||
"contains %d replicas: %s",
|
||||
replica_count, replicas_seen.size(), hlo->ToString());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -859,8 +859,17 @@ string ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups) {
|
||||
return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
|
||||
}
|
||||
|
||||
// Returns the total number of replica ids listed across all `replica_groups`,
// i.e. the sum of the individual group sizes. Ids are counted as written (no
// de-duplication), and an empty list of groups yields 0.
int64 ReplicaCount(const std::vector<std::vector<int64>>& replica_groups) {
  int64 replica_count = 0;
  // Bind each group by const reference: iterating by value would copy every
  // inner vector just to read its size.
  for (const auto& group : replica_groups) {
    replica_count += group.size();
  }
  return replica_count;
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
|
||||
std::vector<std::vector<int64>> replica_groups) {
|
||||
std::vector<std::vector<int64>> replica_groups,
|
||||
absl::optional<int64> replica_count = absl::nullopt) {
|
||||
const char* kTemplate = R"(
|
||||
HloModule test
|
||||
add {
|
||||
@ -872,8 +881,17 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
|
||||
p = f32[128]{0} parameter(0)
|
||||
crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
|
||||
})";
|
||||
return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
|
||||
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
|
||||
|
||||
HloModuleConfig config;
|
||||
if (replica_count) {
|
||||
config.set_replica_count(*replica_count);
|
||||
} else {
|
||||
config.set_replica_count(ReplicaCount(replica_groups));
|
||||
}
|
||||
return ParseAndReturnUnverifiedModule(
|
||||
absl::StrReplaceAll(
|
||||
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
|
||||
config);
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
|
||||
@ -907,6 +925,21 @@ TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
|
||||
HasSubstr("Replica 4 is not named"));
|
||||
}
|
||||
|
||||
// The module is configured with 8 replicas, but the all-reduce's replica
// groups only list 2 of them; the verifier must report the mismatch.
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      MakeAllReduceComputation({{0, 1}}, /*replica_count=*/8));

  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica count in HloModuleConfig is 8, but "
                        "ReplicaGroup config contains 2 replicas"));
}
|
||||
|
||||
// The module is configured with 2 replicas, but the all-reduce's replica
// groups list 4 of them; the verifier must report the mismatch.
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, /*replica_count=*/2));

  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica count in HloModuleConfig is 2, but "
                        "ReplicaGroup config contains 4 replicas"));
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
|
||||
std::vector<std::vector<int64>> replica_groups) {
|
||||
const char* kTemplate = R"(
|
||||
@ -921,8 +954,12 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
|
||||
p1 = f32[128]{0} parameter(1)
|
||||
a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
|
||||
})";
|
||||
return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
|
||||
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
|
||||
HloModuleConfig config;
|
||||
config.set_replica_count(ReplicaCount(replica_groups));
|
||||
return ParseAndReturnUnverifiedModule(
|
||||
absl::StrReplaceAll(
|
||||
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
|
||||
config);
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
|
||||
@ -966,8 +1003,10 @@ TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
|
||||
source_target_pairs={{0,1}, {0,2}, {1,0}}
|
||||
}
|
||||
)";
|
||||
HloModuleConfig config;
|
||||
config.set_replica_count(3);
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnUnverifiedModule(kModuleStr, config));
|
||||
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||
HasSubstr("Source 0 appears more than once"));
|
||||
}
|
||||
|
@ -117,16 +117,18 @@ std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
|
||||
}
|
||||
|
||||
std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
|
||||
const string& name) {
|
||||
const string& name, int64 replica_count) {
|
||||
return absl::make_unique<VerifiedHloModule>(
|
||||
name, GetModuleConfigForTest(), verifier_layout_sensitive_,
|
||||
name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_,
|
||||
allow_mixed_precision_in_hlo_verifier_,
|
||||
backend().compiler()->ShapeSizeBytesFunction());
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
|
||||
return ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
|
||||
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
|
||||
int64 replica_count) {
|
||||
return ParseAndReturnVerifiedModule(hlo_text,
|
||||
GetModuleConfigForTest(replica_count));
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||
|
@ -84,11 +84,11 @@ class HloTestBase : public ::testing::Test {
|
||||
// Like CreateNewUnverifiedModule, except the HloModule returned here runs the
|
||||
// HLO verifier on destruction.
|
||||
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
|
||||
const string& name = TestName());
|
||||
const string& name = TestName(), int64 replica_count = 1);
|
||||
|
||||
// Parses the given string and returns module as a VerifiedHloModule.
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||
absl::string_view hlo_text);
|
||||
absl::string_view hlo_text, int64 replica_count = 1);
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||
absl::string_view hlo_text, const HloModuleConfig& config);
|
||||
|
||||
@ -130,9 +130,10 @@ class HloTestBase : public ::testing::Test {
|
||||
virtual DebugOptions GetDebugOptionsForTest();
|
||||
|
||||
// Gets an HloModuleConfig with options appropriate for tests.
|
||||
HloModuleConfig GetModuleConfigForTest() {
|
||||
HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) {
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(GetDebugOptionsForTest());
|
||||
config.set_replica_count(replica_count);
|
||||
return config;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user