From 127aa2a6c04603ab15626f6147a48ae734d2c520 Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Mon, 27 Apr 2020 08:50:29 -0700 Subject: [PATCH] Verify replica count from AllReduce replica group config PiperOrigin-RevId: 308628769 Change-Id: I636c90d9d153c8f5bf21a909ab65a47659465919 --- .../xla/service/all_reduce_combiner_test.cc | 2 +- .../xla/service/all_reduce_simplifier_test.cc | 12 +- .../xla/service/ar_crs_combiner_test.cc | 130 +++++++++++------- .../service/bfloat16_normalization_test.cc | 4 +- .../compiler/xla/service/hlo_parser_test.cc | 17 ++- .../service/hlo_replication_analysis_test.cc | 12 +- .../compiler/xla/service/hlo_verifier.cc | 23 ++++ .../compiler/xla/service/hlo_verifier_test.cc | 51 ++++++- .../compiler/xla/tests/hlo_test_base.cc | 10 +- tensorflow/compiler/xla/tests/hlo_test_base.h | 7 +- 10 files changed, 183 insertions(+), 85 deletions(-) diff --git a/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc b/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc index b486612ff83..0b41f374900 100644 --- a/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc +++ b/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc @@ -238,7 +238,7 @@ TEST_F(AllReduceCombinerTest, NoDependentCombination) { // Tests that AllReduce ops with different groups are not combined. TEST_F(AllReduceCombinerTest, GroupAllReduce) { - auto module = CreateNewVerifiedModule(); + auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/4); HloComputation::Builder b(TestName()); HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get()); diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc index 2e03e67c59c..4914836b34a 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc @@ -78,8 +78,8 @@ test { ROOT tuple = (f32[8,16], f32[8,16], f32[8,16], f32[]) tuple(all-reduce, all-reduce.1, all-reduce.2, all-reduce.3) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + kModuleStr, /*replica_count=*/8)); AllReduceSimplifier simplifier(/*replica_count=*/8); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT( @@ -114,8 +114,8 @@ test { ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), replica_groups={}, to_apply=sum } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + kModuleStr, /*replica_count=*/8)); AllReduceSimplifier simplifier(/*replica_count=*/8); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -155,8 +155,8 @@ test { ROOT tuple = (f32[8,16], f32[8,16], f32[8,16]) tuple(all-reduce, all-reduce.1, all-reduce.2) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + kModuleStr, /*replica_count=*/8)); AllReduceSimplifier simplifier(/*replica_count=*/8); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT( diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index a02d5a86a27..bfa8f1020e5 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -447,8 +447,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -497,8 +498,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -565,8 +567,9 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -633,8 +636,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -675,8 +679,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -756,8 +761,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -809,8 +815,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -891,8 +898,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -929,8 +937,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -987,8 +996,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1062,8 +1072,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1110,8 +1121,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1180,8 +1192,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1224,8 +1237,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1312,8 +1326,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1363,8 +1378,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1452,8 +1468,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1502,8 +1519,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1579,8 +1597,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -1616,8 +1635,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -1691,8 +1711,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -1719,8 +1740,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -1739,14 +1761,17 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] { %p = f32[2,4] parameter(0), sharding={replicated} - ROOT %all-reduce = f32[2,4] all-reduce(%p), replica_groups={{0,1}}, - to_apply=%sum.f32 + ROOT %all-reduce = f32[2,4] all-reduce(%p), to_apply=%sum.f32, + replica_groups={{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}} } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/64, + // Replacing replicated all-reduce is only triggered when there are enough + // replicas (currently > num_partitions * 8). + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/32)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/32, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); @@ -1758,7 +1783,7 @@ ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] { auto ar = root->operand(0); auto divisor = root->operand(1)->operand(0); EXPECT_TRUE(ar->channel_id()); - EXPECT_TRUE(divisor->literal().IsAllFloat(4)); + EXPECT_TRUE(divisor->literal().IsAllFloat(2)); } TEST_F(ArCrsCombinerTest, AllReduceWithGlobalIdReplicaGroups) { @@ -1782,8 +1807,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 78924908015..943c8de65dc 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -275,7 +275,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) { - auto module = CreateNewVerifiedModule(); + auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -304,7 +304,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) { - auto module = CreateNewVerifiedModule(); + auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 66ce7d821f0..3d1f21ee8be 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -42,6 +42,7 @@ using absl::string_view; struct TestData { string test_name; string module_string; + int64 replica_count = 1; bool enable_verification = true; }; @@ -1439,7 +1440,8 @@ ENTRY AllReduceWithSubgroups { ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add } -)" +)", +/*replica_count=*/4, }, // all-reduce with constrained layout { @@ -1501,7 +1503,8 @@ ENTRY AllToAllWithSubgroups { ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}} } -)" +)", +/*replica_count=*/4, }, // collective-permute { @@ -1513,7 +1516,8 @@ ENTRY CollectivePermute { ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } -)" +)", +/*replica_count=*/4 }, // replica-id { @@ -1686,16 +1690,19 @@ class HloParameterizedParserTest void ExpectEqual() { std::unique_ptr module; const string& original = GetParam().module_string; + HloModuleConfig config; + config.set_replica_count(GetParam().replica_count); if (GetParam().enable_verification) { auto verified_module = absl::make_unique( - GetParam().test_name, HloModuleConfig(), + GetParam().test_name, config, /*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/true, ShapeUtil::ByteSizeOfElements); TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original)); module = std::move(verified_module); } else { - TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original)); + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnUnverifiedModule(original, config)); } if (proto_round_trip) { TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto( diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc index 822b00aecbf..d858d6aa1c7 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -69,8 +69,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + module_str, /*replica_count=*/4)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{false, true}); @@ -149,8 +149,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + module_str, /*replica_count=*/4)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr analysis, HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true)); @@ -575,8 +575,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + module_str, /*replica_count=*/2)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, HloReplicationAnalysis::Run( module.get(), /*cross_partition_spmd=*/false)); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index a8f9f612b0f..0911af10f38 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -210,6 +210,29 @@ static Status CheckReplicaGroups(HloInstruction* hlo) { hlo->ToString()); } } + + // When the channel_id() or use_global_device_ids() is set, device ids in + // ReplicaGroup config no longer only mean replica ids. So we skip the check + // on the replica count. + if (auto channel_instr = DynCast(hlo)) { + if (channel_instr->channel_id()) { + return Status::OK(); + } + } + if (auto all_reduce = DynCast(hlo)) { + if (all_reduce->use_global_device_ids()) { + return Status::OK(); + } + } + + int64 replica_count = hlo->GetModule()->config().replica_count(); + if (!replicas_seen.empty() && replicas_seen.size() != replica_count) { + return InternalError( + "Replica count in HloModuleConfig is %d, but ReplicaGroup config " + "contains %d replicas: %s", + replica_count, replicas_seen.size(), hlo->ToString()); + } + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 8b2b7f6726a..b13e812dc0f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -859,8 +859,17 @@ string ReplicaGroupsStr(std::vector> replica_groups) { return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", ")); } +int64 ReplicaCount(const std::vector>& replica_groups) { + int64 replica_count = 0; + for (auto group : replica_groups) { + replica_count += group.size(); + } + return replica_count; +} + StatusOr> MakeAllReduceComputation( - std::vector> replica_groups) { + std::vector> replica_groups, + absl::optional replica_count = absl::nullopt) { const char* kTemplate = R"( HloModule test add { @@ -872,8 +881,17 @@ StatusOr> MakeAllReduceComputation( p = f32[128]{0} parameter(0) crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS })"; - return ParseAndReturnUnverifiedModule(absl::StrReplaceAll( - kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}})); + + HloModuleConfig config; + if (replica_count) { + config.set_replica_count(*replica_count); + } else { + config.set_replica_count(ReplicaCount(replica_groups)); + } + return ParseAndReturnUnverifiedModule( + absl::StrReplaceAll( + kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}), + config); } TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) { @@ -907,6 +925,21 @@ TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) { HasSubstr("Replica 4 is not named")); } +TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) { + TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0, 1}}, 8)); + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("Replica count in HloModuleConfig is 8, but " + "ReplicaGroup config contains 2 replicas")); +} + +TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + MakeAllReduceComputation({{0, 1}, {2, 3}}, 2)); + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("Replica count in HloModuleConfig is 2, but " + "ReplicaGroup config contains 4 replicas")); +} + StatusOr> MakeAllToAllComputation( std::vector> replica_groups) { const char* kTemplate = R"( @@ -921,8 +954,12 @@ StatusOr> MakeAllToAllComputation( p1 = f32[128]{0} parameter(1) a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS })"; - return ParseAndReturnUnverifiedModule(absl::StrReplaceAll( - kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}})); + HloModuleConfig config; + config.set_replica_count(ReplicaCount(replica_groups)); + return ParseAndReturnUnverifiedModule( + absl::StrReplaceAll( + kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}), + config); } TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) { @@ -966,8 +1003,10 @@ TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) { source_target_pairs={{0,1}, {0,2}, {1,0}} } )"; + HloModuleConfig config; + config.set_replica_count(3); TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnUnverifiedModule(kModuleStr, config)); EXPECT_THAT(verifier().Run(module.get()).status().error_message(), HasSubstr("Source 0 appears more than once")); } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 64d586a9514..8eed609a134 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -117,16 +117,18 @@ std::unique_ptr HloTestBase::CreateNewUnverifiedModule( } std::unique_ptr HloTestBase::CreateNewVerifiedModule( - const string& name) { + const string& name, int64 replica_count) { return absl::make_unique( - name, GetModuleConfigForTest(), verifier_layout_sensitive_, + name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_, allow_mixed_precision_in_hlo_verifier_, backend().compiler()->ShapeSizeBytesFunction()); } StatusOr> -HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) { - return ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()); +HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + int64 replica_count) { + return ParseAndReturnVerifiedModule(hlo_text, + GetModuleConfigForTest(replica_count)); } StatusOr> diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 0b1801ebe23..d05776a0cb9 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -84,11 +84,11 @@ class HloTestBase : public ::testing::Test { // Like CreateNewUnverifiedModule, except the HloModule returned here runs the // HLO verifier on destruction. std::unique_ptr CreateNewVerifiedModule( - const string& name = TestName()); + const string& name = TestName(), int64 replica_count = 1); // Parses the given string and returns module as a VerifiedHloModule. StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text); + absl::string_view hlo_text, int64 replica_count = 1); StatusOr> ParseAndReturnVerifiedModule( absl::string_view hlo_text, const HloModuleConfig& config); @@ -130,9 +130,10 @@ class HloTestBase : public ::testing::Test { virtual DebugOptions GetDebugOptionsForTest(); // Gets an HloModuleConfig with options appropriate for tests. - HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); + config.set_replica_count(replica_count); return config; }