Verify replica count from AllReduce replica group config
PiperOrigin-RevId: 308628769 Change-Id: I636c90d9d153c8f5bf21a909ab65a47659465919
This commit is contained in:
parent
491b81b78d
commit
127aa2a6c0
@ -238,7 +238,7 @@ TEST_F(AllReduceCombinerTest, NoDependentCombination) {
|
|||||||
|
|
||||||
// Tests that AllReduce ops with different groups are not combined.
|
// Tests that AllReduce ops with different groups are not combined.
|
||||||
TEST_F(AllReduceCombinerTest, GroupAllReduce) {
|
TEST_F(AllReduceCombinerTest, GroupAllReduce) {
|
||||||
auto module = CreateNewVerifiedModule();
|
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/4);
|
||||||
HloComputation::Builder b(TestName());
|
HloComputation::Builder b(TestName());
|
||||||
HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());
|
HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());
|
||||||
|
|
||||||
|
@ -78,8 +78,8 @@ test {
|
|||||||
ROOT tuple = (f32[8,16], f32[8,16], f32[8,16], f32[]) tuple(all-reduce, all-reduce.1, all-reduce.2, all-reduce.3)
|
ROOT tuple = (f32[8,16], f32[8,16], f32[8,16], f32[]) tuple(all-reduce, all-reduce.1, all-reduce.2, all-reduce.3)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
ParseAndReturnVerifiedModule(kModuleStr));
|
kModuleStr, /*replica_count=*/8));
|
||||||
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
||||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
@ -114,8 +114,8 @@ test {
|
|||||||
ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), replica_groups={}, to_apply=sum
|
ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), replica_groups={}, to_apply=sum
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
ParseAndReturnVerifiedModule(kModuleStr));
|
kModuleStr, /*replica_count=*/8));
|
||||||
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
||||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
@ -155,8 +155,8 @@ test {
|
|||||||
ROOT tuple = (f32[8,16], f32[8,16], f32[8,16]) tuple(all-reduce, all-reduce.1, all-reduce.2)
|
ROOT tuple = (f32[8,16], f32[8,16], f32[8,16]) tuple(all-reduce, all-reduce.1, all-reduce.2)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
ParseAndReturnVerifiedModule(kModuleStr));
|
kModuleStr, /*replica_count=*/8));
|
||||||
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
||||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
|
@ -447,8 +447,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -497,8 +498,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -565,8 +567,9 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -633,8 +636,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -675,8 +679,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -756,8 +761,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -809,8 +815,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -891,8 +898,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
||||||
/*spmd_partition=*/false);
|
/*spmd_partition=*/false);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
@ -929,8 +937,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
||||||
/*spmd_partition=*/true);
|
/*spmd_partition=*/true);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
@ -987,8 +996,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1062,8 +1072,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1110,8 +1121,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1180,8 +1192,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1224,8 +1237,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1312,8 +1326,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1363,8 +1378,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1452,8 +1468,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1502,8 +1519,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
auto crs_before =
|
auto crs_before =
|
||||||
module->entry_computation()->root_instruction()->operands()[0];
|
module->entry_computation()->root_instruction()->operands()[0];
|
||||||
auto replica_groups_before = crs_before->replica_groups();
|
auto replica_groups_before = crs_before->replica_groups();
|
||||||
@ -1579,8 +1597,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1));
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
|
||||||
/*spmd_partition=*/false);
|
/*spmd_partition=*/false);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
@ -1616,8 +1635,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1));
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
|
||||||
/*spmd_partition=*/true);
|
/*spmd_partition=*/true);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
@ -1691,8 +1711,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
||||||
/*spmd_partition=*/false);
|
/*spmd_partition=*/false);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
@ -1719,8 +1740,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
|
||||||
/*spmd_partition=*/true);
|
/*spmd_partition=*/true);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
@ -1739,14 +1761,17 @@ HloModule foobar
|
|||||||
|
|
||||||
ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
|
ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
|
||||||
%p = f32[2,4] parameter(0), sharding={replicated}
|
%p = f32[2,4] parameter(0), sharding={replicated}
|
||||||
ROOT %all-reduce = f32[2,4] all-reduce(%p), replica_groups={{0,1}},
|
ROOT %all-reduce = f32[2,4] all-reduce(%p), to_apply=%sum.f32,
|
||||||
to_apply=%sum.f32
|
replica_groups={{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}}
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
// Replacing replicated all-reduce is only triggered when there are enough
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
// replicas (currently > num_partitions * 8).
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/64,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/32));
|
||||||
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/32,
|
||||||
/*spmd_partition=*/true);
|
/*spmd_partition=*/true);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
EXPECT_TRUE(changed);
|
EXPECT_TRUE(changed);
|
||||||
@ -1758,7 +1783,7 @@ ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
|
|||||||
auto ar = root->operand(0);
|
auto ar = root->operand(0);
|
||||||
auto divisor = root->operand(1)->operand(0);
|
auto divisor = root->operand(1)->operand(0);
|
||||||
EXPECT_TRUE(ar->channel_id());
|
EXPECT_TRUE(ar->channel_id());
|
||||||
EXPECT_TRUE(divisor->literal().IsAllFloat(4));
|
EXPECT_TRUE(divisor->literal().IsAllFloat(2));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ArCrsCombinerTest, AllReduceWithGlobalIdReplicaGroups) {
|
TEST_F(ArCrsCombinerTest, AllReduceWithGlobalIdReplicaGroups) {
|
||||||
@ -1782,8 +1807,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2,
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2,
|
||||||
/*spmd_partition=*/true);
|
/*spmd_partition=*/true);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
|
@ -275,7 +275,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
|
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
|
||||||
auto module = CreateNewVerifiedModule();
|
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2);
|
||||||
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
|
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
|
||||||
@ -304,7 +304,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
|
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
|
||||||
auto module = CreateNewVerifiedModule();
|
auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2);
|
||||||
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
|
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
|
||||||
|
@ -42,6 +42,7 @@ using absl::string_view;
|
|||||||
struct TestData {
|
struct TestData {
|
||||||
string test_name;
|
string test_name;
|
||||||
string module_string;
|
string module_string;
|
||||||
|
int64 replica_count = 1;
|
||||||
bool enable_verification = true;
|
bool enable_verification = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1439,7 +1440,8 @@ ENTRY AllReduceWithSubgroups {
|
|||||||
ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
|
ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)",
|
||||||
|
/*replica_count=*/4,
|
||||||
},
|
},
|
||||||
// all-reduce with constrained layout
|
// all-reduce with constrained layout
|
||||||
{
|
{
|
||||||
@ -1501,7 +1503,8 @@ ENTRY AllToAllWithSubgroups {
|
|||||||
ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
|
ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)",
|
||||||
|
/*replica_count=*/4,
|
||||||
},
|
},
|
||||||
// collective-permute
|
// collective-permute
|
||||||
{
|
{
|
||||||
@ -1513,7 +1516,8 @@ ENTRY CollectivePermute {
|
|||||||
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
|
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)",
|
||||||
|
/*replica_count=*/4
|
||||||
},
|
},
|
||||||
// replica-id
|
// replica-id
|
||||||
{
|
{
|
||||||
@ -1686,16 +1690,19 @@ class HloParameterizedParserTest
|
|||||||
void ExpectEqual() {
|
void ExpectEqual() {
|
||||||
std::unique_ptr<HloModule> module;
|
std::unique_ptr<HloModule> module;
|
||||||
const string& original = GetParam().module_string;
|
const string& original = GetParam().module_string;
|
||||||
|
HloModuleConfig config;
|
||||||
|
config.set_replica_count(GetParam().replica_count);
|
||||||
if (GetParam().enable_verification) {
|
if (GetParam().enable_verification) {
|
||||||
auto verified_module = absl::make_unique<VerifiedHloModule>(
|
auto verified_module = absl::make_unique<VerifiedHloModule>(
|
||||||
GetParam().test_name, HloModuleConfig(),
|
GetParam().test_name, config,
|
||||||
/*verifier_layout_sensitive=*/false,
|
/*verifier_layout_sensitive=*/false,
|
||||||
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||||
ShapeUtil::ByteSizeOfElements);
|
ShapeUtil::ByteSizeOfElements);
|
||||||
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
|
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
|
||||||
module = std::move(verified_module);
|
module = std::move(verified_module);
|
||||||
} else {
|
} else {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
|
TF_ASSERT_OK_AND_ASSIGN(module,
|
||||||
|
ParseAndReturnUnverifiedModule(original, config));
|
||||||
}
|
}
|
||||||
if (proto_round_trip) {
|
if (proto_round_trip) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
|
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
|
||||||
|
@ -69,8 +69,8 @@ ENTRY entry {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
module_str, /*replica_count=*/4));
|
||||||
auto param = module->entry_computation()->parameter_instruction(0);
|
auto param = module->entry_computation()->parameter_instruction(0);
|
||||||
param->set_parameter_replicated_at_leaf_buffers(
|
param->set_parameter_replicated_at_leaf_buffers(
|
||||||
absl::Span<const bool>{false, true});
|
absl::Span<const bool>{false, true});
|
||||||
@ -149,8 +149,8 @@ ENTRY entry {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
module_str, /*replica_count=*/4));
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
std::unique_ptr<HloReplicationAnalysis> analysis,
|
std::unique_ptr<HloReplicationAnalysis> analysis,
|
||||||
HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true));
|
HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true));
|
||||||
@ -575,8 +575,8 @@ ENTRY entry {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
ParseAndReturnVerifiedModule(module_str));
|
module_str, /*replica_count=*/2));
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
|
||||||
HloReplicationAnalysis::Run(
|
HloReplicationAnalysis::Run(
|
||||||
module.get(), /*cross_partition_spmd=*/false));
|
module.get(), /*cross_partition_spmd=*/false));
|
||||||
|
@ -210,6 +210,29 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
|
|||||||
hlo->ToString());
|
hlo->ToString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When the channel_id() or use_global_device_ids() is set, device ids in
|
||||||
|
// ReplicaGroup config no longer only mean replica ids. So we skip the check
|
||||||
|
// on the replica count.
|
||||||
|
if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
|
||||||
|
if (channel_instr->channel_id()) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (auto all_reduce = DynCast<HloAllReduceInstruction>(hlo)) {
|
||||||
|
if (all_reduce->use_global_device_ids()) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 replica_count = hlo->GetModule()->config().replica_count();
|
||||||
|
if (!replicas_seen.empty() && replicas_seen.size() != replica_count) {
|
||||||
|
return InternalError(
|
||||||
|
"Replica count in HloModuleConfig is %d, but ReplicaGroup config "
|
||||||
|
"contains %d replicas: %s",
|
||||||
|
replica_count, replicas_seen.size(), hlo->ToString());
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -859,8 +859,17 @@ string ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups) {
|
|||||||
return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
|
return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 ReplicaCount(const std::vector<std::vector<int64>>& replica_groups) {
|
||||||
|
int64 replica_count = 0;
|
||||||
|
for (auto group : replica_groups) {
|
||||||
|
replica_count += group.size();
|
||||||
|
}
|
||||||
|
return replica_count;
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
|
StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
|
||||||
std::vector<std::vector<int64>> replica_groups) {
|
std::vector<std::vector<int64>> replica_groups,
|
||||||
|
absl::optional<int64> replica_count = absl::nullopt) {
|
||||||
const char* kTemplate = R"(
|
const char* kTemplate = R"(
|
||||||
HloModule test
|
HloModule test
|
||||||
add {
|
add {
|
||||||
@ -872,8 +881,17 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
|
|||||||
p = f32[128]{0} parameter(0)
|
p = f32[128]{0} parameter(0)
|
||||||
crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
|
crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
|
||||||
})";
|
})";
|
||||||
return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
|
|
||||||
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
|
HloModuleConfig config;
|
||||||
|
if (replica_count) {
|
||||||
|
config.set_replica_count(*replica_count);
|
||||||
|
} else {
|
||||||
|
config.set_replica_count(ReplicaCount(replica_groups));
|
||||||
|
}
|
||||||
|
return ParseAndReturnUnverifiedModule(
|
||||||
|
absl::StrReplaceAll(
|
||||||
|
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
|
||||||
|
config);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
|
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
|
||||||
@ -907,6 +925,21 @@ TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
|
|||||||
HasSubstr("Replica 4 is not named"));
|
HasSubstr("Replica 4 is not named"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0, 1}}, 8));
|
||||||
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
|
HasSubstr("Replica count in HloModuleConfig is 8, but "
|
||||||
|
"ReplicaGroup config contains 2 replicas"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
|
||||||
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
|
HasSubstr("Replica count in HloModuleConfig is 2, but "
|
||||||
|
"ReplicaGroup config contains 4 replicas"));
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
|
StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
|
||||||
std::vector<std::vector<int64>> replica_groups) {
|
std::vector<std::vector<int64>> replica_groups) {
|
||||||
const char* kTemplate = R"(
|
const char* kTemplate = R"(
|
||||||
@ -921,8 +954,12 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
|
|||||||
p1 = f32[128]{0} parameter(1)
|
p1 = f32[128]{0} parameter(1)
|
||||||
a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
|
a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
|
||||||
})";
|
})";
|
||||||
return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
|
HloModuleConfig config;
|
||||||
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
|
config.set_replica_count(ReplicaCount(replica_groups));
|
||||||
|
return ParseAndReturnUnverifiedModule(
|
||||||
|
absl::StrReplaceAll(
|
||||||
|
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
|
||||||
|
config);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
|
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
|
||||||
@ -966,8 +1003,10 @@ TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
|
|||||||
source_target_pairs={{0,1}, {0,2}, {1,0}}
|
source_target_pairs={{0,1}, {0,2}, {1,0}}
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
HloModuleConfig config;
|
||||||
|
config.set_replica_count(3);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
ParseAndReturnUnverifiedModule(kModuleStr, config));
|
||||||
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
HasSubstr("Source 0 appears more than once"));
|
HasSubstr("Source 0 appears more than once"));
|
||||||
}
|
}
|
||||||
|
@ -117,16 +117,18 @@ std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
|
std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
|
||||||
const string& name) {
|
const string& name, int64 replica_count) {
|
||||||
return absl::make_unique<VerifiedHloModule>(
|
return absl::make_unique<VerifiedHloModule>(
|
||||||
name, GetModuleConfigForTest(), verifier_layout_sensitive_,
|
name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_,
|
||||||
allow_mixed_precision_in_hlo_verifier_,
|
allow_mixed_precision_in_hlo_verifier_,
|
||||||
backend().compiler()->ShapeSizeBytesFunction());
|
backend().compiler()->ShapeSizeBytesFunction());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||||
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
|
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
|
||||||
return ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
|
int64 replica_count) {
|
||||||
|
return ParseAndReturnVerifiedModule(hlo_text,
|
||||||
|
GetModuleConfigForTest(replica_count));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||||
|
@ -84,11 +84,11 @@ class HloTestBase : public ::testing::Test {
|
|||||||
// Like CreateNewUnverifiedModule, except the HloModule returned here runs the
|
// Like CreateNewUnverifiedModule, except the HloModule returned here runs the
|
||||||
// HLO verifier on destruction.
|
// HLO verifier on destruction.
|
||||||
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
|
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
|
||||||
const string& name = TestName());
|
const string& name = TestName(), int64 replica_count = 1);
|
||||||
|
|
||||||
// Parses the given string and returns module as a VerifiedHloModule.
|
// Parses the given string and returns module as a VerifiedHloModule.
|
||||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||||
absl::string_view hlo_text);
|
absl::string_view hlo_text, int64 replica_count = 1);
|
||||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||||
absl::string_view hlo_text, const HloModuleConfig& config);
|
absl::string_view hlo_text, const HloModuleConfig& config);
|
||||||
|
|
||||||
@ -130,9 +130,10 @@ class HloTestBase : public ::testing::Test {
|
|||||||
virtual DebugOptions GetDebugOptionsForTest();
|
virtual DebugOptions GetDebugOptionsForTest();
|
||||||
|
|
||||||
// Gets an HloModuleConfig with options appropriate for tests.
|
// Gets an HloModuleConfig with options appropriate for tests.
|
||||||
HloModuleConfig GetModuleConfigForTest() {
|
HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) {
|
||||||
HloModuleConfig config;
|
HloModuleConfig config;
|
||||||
config.set_debug_options(GetDebugOptionsForTest());
|
config.set_debug_options(GetDebugOptionsForTest());
|
||||||
|
config.set_replica_count(replica_count);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user