From 3c28370a9c66f12f04d6a595d5e46eaf5c460d1f Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Tue, 3 Dec 2019 13:28:24 -0800 Subject: [PATCH] Combine cross-replica / cross-partition AllReduce after SPMD partition PiperOrigin-RevId: 283610192 Change-Id: I801097d159c39d8137457c55906d455e0ee7733d --- tensorflow/compiler/xla/service/BUILD | 1 + .../xla/service/all_reduce_simplifier.cc | 4 +- .../compiler/xla/service/ar_crs_combiner.cc | 52 +- .../compiler/xla/service/ar_crs_combiner.h | 37 +- .../xla/service/ar_crs_combiner_test.cc | 455 +++++++++++++++++- .../xla/service/hlo_replication_analysis.cc | 103 ++-- .../xla/service/hlo_replication_analysis.h | 34 +- .../service/hlo_replication_analysis_test.cc | 131 ++++- 8 files changed, 736 insertions(+), 81 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 23d203850fc..a6300d2dc73 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -4197,6 +4197,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:hlo_replication_analysis", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc index b3097b8ff77..541006f04d5 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc @@ -28,7 +28,9 @@ limitations under the License. namespace xla { StatusOr AllReduceSimplifier::Run(HloModule* module) { - TF_ASSIGN_OR_RETURN(auto replication, HloReplicationAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN( + auto replication, + HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false)); std::vector all_reduces_to_replace; for (auto computation : module->computations()) { for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index ae39906ef52..06aaad351e6 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -240,7 +241,8 @@ bool ArCrsCombiner::TupleElementsComputeSameValue( /* static */ bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1, HloInstruction* i2) { - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, + /*spmd_partition=*/false); auto module = i1->parent()->parent(); CHECK_EQ(module, i2->parent()->parent()); combiner.call_graph_ = CallGraph::Build(module); @@ -363,14 +365,14 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { } } -void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { +Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() { for (auto it : all_reduce_map_) { auto channel_id = it.first; VLOG(2) << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: " << channel_id << "\n"; auto pairs_vec = it.second; - CHECK_EQ(pairs_vec.size(), num_spatial_partitions_); + TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_); auto instr_0 = pairs_vec[0].ar; for (int i = 1; i < pairs_vec.size(); ++i) { auto instr_i = pairs_vec[i].ar; @@ -393,6 +395,44 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { } } } + return Status::OK(); +} + +Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD( + HloModule* module) { + // For SPMD mode, use HloReplicationAnalysis to figure out HLO value + // equivalence across partitions. + TF_ASSIGN_OR_RETURN( + auto replication_analysis, + HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true)); + + for (auto it : all_reduce_map_) { + auto channel_id = it.first; + VLOG(2) + << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: " + << channel_id << "\n"; + auto pairs_vec = it.second; + TF_RET_CHECK(pairs_vec.size() == 1); + auto instr = pairs_vec[0].ar; + auto next = instr->users()[0]; + while (true) { + // The patterns we detect in ArCrsCombiner::MatchesArCrsPattern() + // guarantee that the HLO produces an array. + TF_RET_CHECK(next->shape().IsArray()); + if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) { + all_reduce_map_.erase(channel_id); + VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce " + "channel id: " + << channel_id << "\n"; + break; + } + if (next->IsCrossReplicaAllReduce()) { + break; + } + next = next->users()[0]; + } + } + return Status::OK(); } StatusOr ArCrsCombiner::RewriteGraph() { @@ -460,7 +500,11 @@ StatusOr ArCrsCombiner::Run(HloModule* module) { GroupAllReducesById(module); - KeepProvablyEqualInstructionGroups(); + if (spmd_partition_) { + TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module)); + } else { + TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD()); + } return RewriteGraph(); } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index a85e18d328c..95443c0c74a 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -25,18 +25,21 @@ limitations under the License. namespace xla { -// When the HLO graph contains a cross-module AllReduce, followed by some simple -// linear operations, followed by a cross-replica AllReduce (also known as -// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an -// efficient AllReduce implementation that fully utilizes the interconnect -// bandwidth. -// Such sequences appear in spatially partitioned models. +// When the HLO graph contains a cross-module AllReduce (N separate AllReduce +// ops that share the same channel_id for MPMD partitioning, or 1 AllReduce op +// for SPMD partitioning), followed by some simple linear operations, followed +// by a cross-replica AllReduce (also known as cross-replica sum, or CRS), we +// can combine the CMAR and the CRAR, to use an efficient AllReduce +// implementation that fully utilizes the interconnect bandwidth. +// +// Such sequences appear in spatially partitioned models (either MPMD or SPMD). // This pass must run right after spatial partitioning, when the code is still // in a single HLO module. // // The steps are: // 1) Find CMARs followed by simple ops followed by CRARs. -// 2) Group CMARs by channel_id. They must all be rewritten. +// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD +// partitioning, there will only be a single CMAR for each channel_id. // 3) Prove that the CMAR patterns in each core produce the same result. // 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the // other operand by the number of spatial partitions. @@ -69,9 +72,11 @@ namespace xla { // class ArCrsCombiner : public HloModulePass { public: - ArCrsCombiner(int num_spatial_partitions, int num_replicas) + ArCrsCombiner(int num_spatial_partitions, int num_replicas, + bool spmd_partition) : num_spatial_partitions_(num_spatial_partitions), - num_replicas_(num_replicas) {} + num_replicas_(num_replicas), + spmd_partition_(spmd_partition) {} absl::string_view name() const override { return "ar-crs-combiner"; } StatusOr Run(HloModule* module) override; @@ -153,7 +158,10 @@ class ArCrsCombiner : public HloModulePass { // Looks at each AllReduce group in all_reduce_map_, and keeps only the // groups for which it's safe to move the AllReduce later in the HLO graph. - void KeepProvablyEqualInstructionGroups(); + Status KeepProvablyEqualInstructionGroupsMPMD(); + + // Same as above, but runs on SPMD partitioned module instead of MPMD. + Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module); // Performs the graph rewrite that eliminates the early AllReduce and turns // the later CRS into an AllReduce. @@ -163,6 +171,15 @@ class ArCrsCombiner : public HloModulePass { int num_replicas_; + // Run this combiner pass assuming the input module is an SPMD partitioned + // module (as opposed to MPMD partitioned). + // + // The main difference between the two w.r.t. this pass is that there would be + // N all-reduce ops for each channel in MPMD mode, whereas there is only 1 + // for each channel in SPMD mode. Also we use HloReplicationAnalysis for HLO + // equivalence check in SPMD mode. + bool spmd_partition_; + // Map from all-reduce ids to the AR/CRS pairs. absl::flat_hash_map> all_reduce_map_; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index accc0684e8e..609da2c33a0 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -452,7 +452,8 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -464,6 +465,55 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteArConvertCrsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[]) { + %p = bf16[] parameter(0) + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0},{1}}, + channel_id=1, + to_apply=%sum.bf16 + %convert.1 = f32[] convert(%all-reduce.ar.1) + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0,1}}, + to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Convert(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) { const char* module_str = R"( HloModule foobar @@ -520,7 +570,8 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -587,7 +638,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -600,6 +652,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum.f32 + %multiply.1 = f32[] multiply(%all-reduce.ar.1, %constant.f32) + %all-reduce.1 = f32[] all-reduce(%multiply.1), replica_groups={{0,1}}, + to_apply=%sum.f32, sharding={maximal device=0} + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) { const char* module_str = R"( HloModule foobar @@ -668,7 +761,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -684,6 +778,55 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32 = f32[] constant(2) + + %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum.bf16 + %convert.1 = f32[] convert(%all-reduce.ar.1), sharding={maximal device=0} + %add.1 = f32[] add(%constant.f32, %convert.1) + %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, + to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { const char* module_str = R"( HloModule foobar @@ -750,7 +893,46 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewriteSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32.1 = f32[] constant(2) + + %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum.bf16 + %convert.1 = f32[] convert(%all-reduce.ar.1) + %add.1 = f32[] add(%p, %convert.1) + %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } @@ -810,7 +992,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -884,7 +1067,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -902,6 +1086,50 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteMultipleAddsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.1 = f32[] constant(1) + %constant.2 = f32[] constant(2) + + %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum + %add.11 = f32[] add(%constant.1, %all-reduce.ar.1) + %add.12 = f32[] add(%constant.2, %add.11) + %all-reduce.1 = f32[] all-reduce(%add.12), replica_groups={{0,1}}, to_apply=%sum + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce( + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) { const char* module_str = R"( HloModule foobar @@ -957,7 +1185,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -973,6 +1202,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteArSubtractCrsSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, + channel_id=1, to_apply=%sum.f32 + %sub.1 = f32[] subtract(%constant.f32, %all-reduce.ar.1) + %all-reduce.1 = f32[] all-reduce(%sub.1), replica_groups={{0,1}}, + to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Subtract( + op::Divide(op::Constant(), op::Constant()), op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) { const char* module_str = R"( HloModule foobar @@ -1047,7 +1317,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -1065,6 +1336,53 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeftSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, + to_apply=%sum + %add11 = f32[] add(%ar11, %const1) + %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, + to_apply=%sum + %add12 = f32[] add(%add11, %ar12) + %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, + to_apply=%sum + ROOT %tuple = (f32[]) tuple(%crs1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())), + op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) { const char* module_str = R"( HloModule foobar @@ -1139,7 +1457,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( @@ -1159,6 +1478,51 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, RewriteMultipleARsRightSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, to_apply=%sum + %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, to_apply=%sum + %add11 = f32[] add(%ar12, %const1) + %add12 = f32[] add(%ar11, %add11) + %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, to_apply=%sum + ROOT %tuple = (f32[]) tuple(%crs1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))))); + + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) { const char* module_str = R"( HloModule foobar @@ -1217,7 +1581,45 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, + /*spmd_partition=*/false); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +TEST_F(ArCrsCombinerTest, OneReplicaDontRewriteSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0}}, + channel_id=1, to_apply=%sum.bf16 + %convert.1 = f32[] convert(%all-reduce.ar.1) + %all-reduce.1 = f32[] all-reduce(%convert.1), + replica_groups={{0}}, to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, + /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } @@ -1291,7 +1693,36 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/false); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +TEST_F(ArCrsCombinerTest, AllReduceWithReplicasSPMD) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[]) { + %p = bf16[] parameter(0) + %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}}, + to_apply=%sum.f32 + %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}}, + to_apply=%sum.f32 + ROOT %tuple = (f32[]) tuple(%all-reduce.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, + /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc index e11d3920f95..3a896d4a113 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc @@ -35,13 +35,45 @@ namespace { // knowledge in hlo_replication. bool DetermineHloInstructionIsReplicated( const HloInstruction* hlo, const ShapeIndex& index, + bool cross_partition_spmd, const absl::flat_hash_map>& hlo_replication) { + // Returns true if all operands are known to be replicated. + const auto all_operands_replicated = + [&hlo_replication](const HloInstruction* inst) { + for (auto operand : inst->operands()) { + auto operand_it = hlo_replication.find(operand); + if (operand_it == hlo_replication.end() || + !operand_it->second.element({})) { + return false; + } + } + return true; + }; + + if (hlo->IsCrossReplicaAllReduce()) { + if (cross_partition_spmd) { + // Cross-replica all-reduce returns same values across partitions as long + // as its operands are replicated. + return all_operands_replicated(hlo); + } + // Only all-reduce across all cores are replicated, which means there + // is only one subgroup. + return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1; + } + if (hlo->IsCrossModuleAllReduce()) { + return cross_partition_spmd; + } if (hlo->HasSideEffectNoRecurse()) { return false; } if (hlo->opcode() == HloOpcode::kReplicaId) { - return false; + // ReplicaId returns the same value for all partitions in each replica. + return cross_partition_spmd; + } + if (hlo->opcode() == HloOpcode::kPartitionId) { + // PartitionId returns the same value for all replicas in each partition. + return !cross_partition_spmd; } auto it = hlo_replication.find(hlo); if (hlo->opcode() == HloOpcode::kParameter) { @@ -55,11 +87,6 @@ bool DetermineHloInstructionIsReplicated( if (hlo->opcode() == HloOpcode::kConstant) { return true; } - if (hlo->opcode() == HloOpcode::kAllReduce) { - // Only all-reduce across all cores are replicated, which means there - // is only one subgroup. - return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1; - } if (hlo->IsElementwise() || // hlo->opcode() == HloOpcode::kConcatenate || // @@ -80,14 +107,7 @@ bool DetermineHloInstructionIsReplicated( hlo->opcode() == HloOpcode::kDynamicUpdateSlice || // hlo->opcode() == HloOpcode::kReduceWindow || // hlo->opcode() == HloOpcode::kCopy) { - for (auto operand : hlo->operands()) { - auto operand_it = hlo_replication.find(operand); - if (operand_it == hlo_replication.end() || - !operand_it->second.element({})) { - return false; - } - } - return true; + return all_operands_replicated(hlo); } return false; } @@ -235,8 +255,8 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( ShapeUtil::ForEachSubshape( inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { *shape_tree.mutable_element(index) = - DetermineHloInstructionIsReplicated(inst, index, - hlo_replication_); + DetermineHloInstructionIsReplicated( + inst, index, cross_partition_spmd_, hlo_replication_); return Status::OK(); }); changed |= assign_or_combine_shapetree(std::move(shape_tree), inst); @@ -248,23 +268,39 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( void HloReplicationAnalysis::ComputeHloReplication() { // Add entry parameters to the above sets according to user annotation. + // Replicated modules read from `parameter_replicated_at_leaf_buffers` whereas + // SPMD partitioned modules read from HloSharding attributes. auto entry = module_->entry_computation(); for (int i = 0; i < entry->num_parameters(); ++i) { auto param = entry->parameter_instruction(i); ShapeTree shape_tree(param->shape(), false); - const auto& replication = param->parameter_replicated_at_leaf_buffers(); - int leaf_index = 0; - ShapeUtil::ForEachSubshape( - param->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsLeafIndex(param->shape(), index)) { + if (cross_partition_spmd_ && param->has_sharding()) { + auto sharding_tree = + param->sharding().AsShapeTree(param->shape()).ValueOrDie(); + ShapeUtil::ForEachSubshape( + param->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(param->shape(), index)) { + return Status::OK(); + } + *shape_tree.mutable_element(index) = + sharding_tree.element(index).IsReplicated(); return Status::OK(); - } - if (replication && replication->at(leaf_index)) { - *shape_tree.mutable_element(index) = true; - } - ++leaf_index; - return Status::OK(); - }); + }); + } else if (!cross_partition_spmd_) { + const auto& replication = param->parameter_replicated_at_leaf_buffers(); + int leaf_index = 0; + ShapeUtil::ForEachSubshape( + param->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(param->shape(), index)) { + return Status::OK(); + } + if (replication && replication->at(leaf_index)) { + *shape_tree.mutable_element(index) = true; + } + ++leaf_index; + return Status::OK(); + }); + } hlo_replication_[param] = std::move(shape_tree); } ComputeHloReplicationOnComputation(entry, @@ -281,17 +317,18 @@ bool HloReplicationAnalysis::HloInstructionIsReplicatedAt( } /* static */ StatusOr> -HloReplicationAnalysis::Run(const HloModule* module) { +HloReplicationAnalysis::Run(const HloModule* module, + bool cross_partition_spmd) { const absl::flat_hash_set empty; - return Run(module, &empty); + return Run(module, cross_partition_spmd, &empty); } /* static */ StatusOr> -HloReplicationAnalysis::Run(const HloModule* module, +HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd, const absl::flat_hash_set* loops_known_with_same_iterations) { - auto analysis = absl::WrapUnique( - new HloReplicationAnalysis(module, loops_known_with_same_iterations)); + auto analysis = absl::WrapUnique(new HloReplicationAnalysis( + module, cross_partition_spmd, loops_known_with_same_iterations)); analysis->ComputeHloReplication(); return analysis; } diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.h b/tensorflow/compiler/xla/service/hlo_replication_analysis.h index 3175fc35102..18b2363e454 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.h @@ -25,32 +25,35 @@ limitations under the License. namespace xla { // An HLO pass that determines whether each instruction in the module outputs -// the same value across replicas. It propagates sources of replicated values to +// the same value across replicas or across partitions (depending on the value +// `cross_partition_spmd`). It propagates sources of replicated values to // the rest of the module, where sources include cross-replica-sum, annotated // entry parameters, and constants. class HloReplicationAnalysis { public: // Runs the analysis on module and returns the result or an error. static StatusOr> Run( - const HloModule* module); + const HloModule* module, bool cross_partition_spmd); // Same as above, but the caller can provide additional annotations: a set of // while loops that are known to have the same iteration counts across - // replicas. + // replicas or partitions. static StatusOr> Run( - const HloModule* module, const absl::flat_hash_set* - loops_known_with_same_iterations); + const HloModule* module, bool cross_partition_spmd, + const absl::flat_hash_set* + loops_known_with_same_iterations); // Returns if the HLO instruction outputs the same value (i.e., replicated) at - // the given index across all replicas. + // the given index across all replicas or partitions. bool HloInstructionIsReplicatedAt(const HloInstruction* inst, const ShapeIndex& index) const; private: - HloReplicationAnalysis(const HloModule* module, + HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd, const absl::flat_hash_set* loops_known_with_same_iterations) : module_(module), + cross_partition_spmd_(cross_partition_spmd), loops_known_with_same_iterations_(*loops_known_with_same_iterations) {} // Computes hlo_replication_. @@ -63,14 +66,25 @@ class HloReplicationAnalysis { const HloModule* module_; + // If true, run this replication analysis for replicated values across + // partitions (not across replicas) on an SPMD partitioned module. This means + // that HloInstructionIsReplicatedAt() returns true if the value is identical + // across partitions for each replica. The module-level parameter and root + // instructions may have HloSharding attributes that indicate whether values + // are identical across partitions. + // + // If false, HloReplicationAnalysis runs across replicas. + bool cross_partition_spmd_; + // A set of while loops that are known to have the same iteration counts - // across replicas. This is provided by the caller as additional annotations. + // across replicas or partitions. This is provided by the caller as additional + // annotations. const absl::flat_hash_set& loops_known_with_same_iterations_; // A map from each analyzed HLO instruction to a shape tree that represents - // whether the instruction outputs the same value across replicas at each - // shape index. + // whether the instruction outputs the same value across replicas or + // partitions at each shape index. absl::flat_hash_map> hlo_replication_; }; diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc index 958e99dedb8..56cc8542ac4 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -42,16 +42,30 @@ sum { ROOT add.2 = f32[] add(a, b) } +sum.u32 { + a = u32[] parameter(0) + b = u32[] parameter(1) + ROOT add.2 = u32[] add(a, b) +} + ENTRY entry { param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0) get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0 get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1 after-all.1 = token[] after-all() + replica-id = u32[] replica-id() + partition-id = u32[] partition-id() infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1) get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0 - dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} - all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, to_apply=sum + dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, + to_apply=sum subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3, all-reduce) + all-reduce-partitions = u32[] all-reduce(partition-id), channel_id=1, + to_apply=sum.u32 + all-reduce-subgroup = u32[] all-reduce(partition-id), + replica_groups={{0,1},{2,3}}, to_apply=sum.u32 ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract) } )"; @@ -62,7 +76,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{false, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "get-tuple-element.2"), {})); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( @@ -77,6 +92,92 @@ ENTRY entry { FindInstruction(module.get(), "subtract"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "add"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "replica-id"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "partition-id"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-partitions"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-subgroup"), {})); +} + +TEST_F(HloReplicationAnalysisTest, NoControlFlowSPMD) { + const string module_str = R"( +HloModule NoControlFlow + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +sum.u32 { + a = u32[] parameter(0) + b = u32[] parameter(1) + ROOT add.2 = u32[] add(a, b) +} + +ENTRY entry { + param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0), + sharding={{maximal device=0}, {replicated}} + get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0 + get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1 + after-all.1 = token[] after-all() + replica-id = u32[] replica-id() + partition-id = u32[] partition-id() + infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1) + get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0 + dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, + to_apply=sum + all-reduce-subgroup = f32[4096,4096]{1,0} all-reduce(dot), + replica_groups={{0,1},{2,3}}, to_apply=sum + all-reduce-partitions = f32[4096,4096]{1,0} all-reduce(get-tuple-element.2), + channel_id=1, to_apply=sum + subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3, + all-reduce-partitions) + all-reduce-same-operand = u32[] all-reduce(replica-id), to_apply=sum.u32 + all-reduce-same-operand-subgroup = u32[] all-reduce(replica-id), + replica_groups={{0,1},{2,3}}, to_apply=sum.u32 + all-reduce-different-operand = u32[] all-reduce(partition-id), + to_apply=sum.u32 + ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true)); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.2"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.3"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.5"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "dot"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "subtract"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "add"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "replica-id"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "partition-id"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-partitions"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-same-operand"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-same-operand-subgroup"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce-different-operand"), {})); } TEST_F(HloReplicationAnalysisTest, NestedCall) { @@ -111,7 +212,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, false}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "get-tuple-element"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -163,7 +265,8 @@ ENTRY SimpleWhileLoop { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple"), {0})); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( @@ -212,7 +315,8 @@ ENTRY WhileLoopParameterAliasingNonReplicatedOutput { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "multiply"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -258,7 +362,8 @@ ENTRY WhileLoopDifferentCondition { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "while"), {0})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -307,7 +412,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, false, true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple"), {0})); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( @@ -371,7 +477,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple"), {0})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -409,7 +516,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, false, true, true, true}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple-select"), {0})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( @@ -435,7 +543,8 @@ ENTRY entry { param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, false}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, - HloReplicationAnalysis::Run(module.get())); + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "tuple-select"), {0})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(