Combine cross-replica / cross-partition AllReduce after SPMD partition

PiperOrigin-RevId: 283610192
Change-Id: I801097d159c39d8137457c55906d455e0ee7733d
HyoukJoong Lee, 2019-12-03 13:28:24 -08:00 (committed by TensorFlower Gardener)
parent 1e7a91e26a
commit 3c28370a9c
8 changed files with 736 additions and 81 deletions

tensorflow/compiler/xla/service/BUILD

@@ -4197,6 +4197,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo_replication_analysis",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],

tensorflow/compiler/xla/service/all_reduce_simplifier.cc

@@ -28,7 +28,9 @@ limitations under the License.
namespace xla {
StatusOr<bool> AllReduceSimplifier::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(auto replication, HloReplicationAnalysis::Run(module));
TF_ASSIGN_OR_RETURN(
auto replication,
HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false));
std::vector<HloInstruction*> all_reduces_to_replace;
for (auto computation : module->computations()) {
for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {

tensorflow/compiler/xla/service/ar_crs_combiner.cc

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -240,7 +241,8 @@ bool ArCrsCombiner::TupleElementsComputeSameValue(
/* static */
bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
HloInstruction* i2) {
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
/*spmd_partition=*/false);
auto module = i1->parent()->parent();
CHECK_EQ(module, i2->parent()->parent());
combiner.call_graph_ = CallGraph::Build(module);
@@ -363,14 +365,14 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
}
}
void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() {
for (auto it : all_reduce_map_) {
auto channel_id = it.first;
VLOG(2)
<< "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
<< channel_id << "\n";
auto pairs_vec = it.second;
CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_);
auto instr_0 = pairs_vec[0].ar;
for (int i = 1; i < pairs_vec.size(); ++i) {
auto instr_i = pairs_vec[i].ar;
@@ -393,6 +395,44 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
}
}
}
return Status::OK();
}
Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD(
HloModule* module) {
// For SPMD mode, use HloReplicationAnalysis to figure out HLO value
// equivalence across partitions.
TF_ASSIGN_OR_RETURN(
auto replication_analysis,
HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
for (auto it : all_reduce_map_) {
auto channel_id = it.first;
VLOG(2)
<< "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
<< channel_id << "\n";
auto pairs_vec = it.second;
TF_RET_CHECK(pairs_vec.size() == 1);
auto instr = pairs_vec[0].ar;
auto next = instr->users()[0];
while (true) {
// The patterns we detect in ArCrsCombiner::MatchesArCrsPattern()
// guarantee that the HLO produces an array.
TF_RET_CHECK(next->shape().IsArray());
if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) {
all_reduce_map_.erase(channel_id);
VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
"channel id: "
<< channel_id << "\n";
break;
}
if (next->IsCrossReplicaAllReduce()) {
break;
}
next = next->users()[0];
}
}
return Status::OK();
}
StatusOr<bool> ArCrsCombiner::RewriteGraph() {
@@ -460,7 +500,11 @@ StatusOr<bool> ArCrsCombiner::Run(HloModule* module) {
GroupAllReducesById(module);
KeepProvablyEqualInstructionGroups();
if (spmd_partition_) {
TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module));
} else {
TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD());
}
return RewriteGraph();
}
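
For orientation, a minimal sketch of driving the pass in SPMD mode after this change (it mirrors the test invocations below; `module` is assumed to be an already SPMD-partitioned HloModule*):

  // Sketch: combine AR/CRS pairs on an SPMD-partitioned module.
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
                         /*spmd_partition=*/true);
  TF_ASSIGN_OR_RETURN(bool changed, combiner.Run(module));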

tensorflow/compiler/xla/service/ar_crs_combiner.h

@@ -25,18 +25,21 @@ limitations under the License.
namespace xla {
// When the HLO graph contains a cross-module AllReduce, followed by some simple
// linear operations, followed by a cross-replica AllReduce (also known as
// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an
// efficient AllReduce implementation that fully utilizes the interconnect
// bandwidth.
// Such sequences appear in spatially partitioned models.
// When the HLO graph contains a cross-module AllReduce (N separate AllReduce
// ops that share the same channel_id for MPMD partitioning, or 1 AllReduce op
// for SPMD partitioning), followed by some simple linear operations, followed
// by a cross-replica AllReduce (also known as cross-replica sum, or CRS), we
// can combine the CMAR and the CRAR, to use an efficient AllReduce
// implementation that fully utilizes the interconnect bandwidth.
//
// Such sequences appear in spatially partitioned models (either MPMD or SPMD).
// This pass must run right after spatial partitioning, when the code is still
// in a single HLO module.
//
// The steps are:
// 1) Find CMARs followed by simple ops followed by CRARs.
// 2) Group CMARs by channel_id. They must all be rewritten.
// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD
// partitioning, there will only be a single CMAR for each channel_id.
// 3) Prove that the CMAR patterns in each core produce the same result.
// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
// other operand by the number of spatial partitions.
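
As a concrete illustration of steps 1-4 (a sketch distilled from the RewriteArConvertAddCrsSPMD test below; names such as %x, %summand, and %num_partitions are illustrative):

  // Before: CMAR -> convert -> add -> CRAR.
  %ar  = bf16[] all-reduce(%x), channel_id=1, replica_groups={{0},{1}}, to_apply=%sum
  %cvt = f32[] convert(%ar)
  %add = f32[] add(%summand, %cvt)
  %crs = f32[] all-reduce(%add), replica_groups={{0,1}}, to_apply=%sum

  // After: the CMAR is eliminated, the other summand is divided by the
  // number of spatial partitions, and the remaining all-reduce performs the
  // combined cross-replica / cross-partition reduction.
  %cvt = f32[] convert(%x)
  %div = f32[] divide(%summand, %num_partitions)
  %add = f32[] add(%div, %cvt)
  %crs = f32[] all-reduce(%add), replica_groups={{0,1}}, to_apply=%sum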
@@ -69,9 +72,11 @@ namespace xla {
//
class ArCrsCombiner : public HloModulePass {
public:
ArCrsCombiner(int num_spatial_partitions, int num_replicas)
ArCrsCombiner(int num_spatial_partitions, int num_replicas,
bool spmd_partition)
: num_spatial_partitions_(num_spatial_partitions),
num_replicas_(num_replicas) {}
num_replicas_(num_replicas),
spmd_partition_(spmd_partition) {}
absl::string_view name() const override { return "ar-crs-combiner"; }
StatusOr<bool> Run(HloModule* module) override;
@@ -153,7 +158,10 @@ class ArCrsCombiner : public HloModulePass {
// Looks at each AllReduce group in all_reduce_map_, and keeps only the
// groups for which it's safe to move the AllReduce later in the HLO graph.
void KeepProvablyEqualInstructionGroups();
Status KeepProvablyEqualInstructionGroupsMPMD();
// Same as above, but runs on an SPMD-partitioned module instead of MPMD.
Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module);
// Performs the graph rewrite that eliminates the early AllReduce and turns
// the later CRS into an AllReduce.
@@ -163,6 +171,15 @@ class ArCrsCombiner : public HloModulePass {
int num_replicas_;
// Run this combiner pass assuming the input module is an SPMD partitioned
// module (as opposed to MPMD partitioned).
//
// The main difference between the two w.r.t. this pass is that there would be
// N all-reduce ops for each channel in MPMD mode, whereas there is only 1
// for each channel in SPMD mode. Also we use HloReplicationAnalysis for HLO
// equivalence check in SPMD mode.
bool spmd_partition_;
// Map from all-reduce ids to the AR/CRS pairs.
absl::flat_hash_map<int64, std::vector<ArCrsPair>> all_reduce_map_;
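
To make the MPMD/SPMD contrast above concrete (an informal sketch; module and operand names are illustrative):

  // MPMD: 2 partitions compile to 2 HLO modules, so a logical cross-module
  // AllReduce appears as 2 ops paired by the shared channel_id.
  module_0:  %ar = f32[] all-reduce(%x0), channel_id=1, to_apply=%sum
  module_1:  %ar = f32[] all-reduce(%x1), channel_id=1, to_apply=%sum

  // SPMD: all partitions share 1 HLO module, so the same logical op appears
  // once, and all_reduce_map_ holds exactly one entry for channel_id=1.
  %ar = f32[] all-reduce(%x), channel_id=1, to_apply=%sum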

tensorflow/compiler/xla/service/ar_crs_combiner_test.cc

@@ -452,7 +452,8 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -464,6 +465,55 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArConvertCrsSPMD) {
const char* module_str = R"(
HloModule foobar
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
%a = bf16[] parameter(0)
%b = bf16[] parameter(1)
ROOT %add = bf16[] add(%a, %b)
}
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
%p = bf16[] parameter(0)
%all-reduce.ar.1 = bf16[]
all-reduce(%p),
replica_groups={{0},{1}},
channel_id=1,
to_apply=%sum.bf16
%convert.1 = f32[] convert(%all-reduce.ar.1)
%all-reduce.1 = f32[]
all-reduce(%convert.1),
replica_groups={{0,1}},
to_apply=%sum.f32
ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::AllReduce(op::Convert(op::Parameter()))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
const char* module_str = R"(
HloModule foobar
@@ -520,7 +570,8 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -587,7 +638,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(
@@ -600,6 +652,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrsSPMD) {
const char* module_str = R"(
HloModule foobar
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[]) {
%p = f32[] parameter(0)
%constant.f32 = f32[] constant(123)
%all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
channel_id=1, to_apply=%sum.f32
%multiply.1 = f32[] multiply(%all-reduce.ar.1, %constant.f32)
%all-reduce.1 = f32[] all-reduce(%multiply.1), replica_groups={{0,1}},
to_apply=%sum.f32, sharding={maximal device=0}
ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant()))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
const char* module_str = R"(
HloModule foobar
@@ -668,7 +761,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(
@@ -684,6 +778,55 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrsSPMD) {
const char* module_str = R"(
HloModule foobar
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
%a = bf16[] parameter(0)
%b = bf16[] parameter(1)
ROOT %add = bf16[] add(%a, %b)
}
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[]) {
%p = f32[] parameter(0)
%constant.bf16 = bf16[] constant(1)
%constant.f32 = f32[] constant(2)
%all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
channel_id=1, to_apply=%sum.bf16
%convert.1 = f32[] convert(%all-reduce.ar.1), sharding={maximal device=0}
%add.1 = f32[] add(%constant.f32, %convert.1)
%all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}},
to_apply=%sum.f32
ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::AllReduce(op::Add(
op::Divide(op::Constant(), op::Constant()), op::Convert()))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
const char* module_str = R"(
HloModule foobar
@@ -750,7 +893,46 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_FALSE(changed);
}
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewriteSPMD) {
const char* module_str = R"(
HloModule foobar
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
%a = bf16[] parameter(0)
%b = bf16[] parameter(1)
ROOT %add = bf16[] add(%a, %b)
}
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[]) {
%p = f32[] parameter(0)
%constant.bf16 = bf16[] constant(1)
%constant.f32.1 = f32[] constant(2)
%all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
channel_id=1, to_apply=%sum.bf16
%convert.1 = f32[] convert(%all-reduce.ar.1)
%add.1 = f32[] add(%p, %convert.1)
%all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32
ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_FALSE(changed);
}
@@ -810,7 +992,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -884,7 +1067,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -902,6 +1086,50 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteMultipleAddsSPMD) {
const char* module_str = R"(
HloModule foobar
%sum (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[]) {
%p = f32[] parameter(0)
%constant.1 = f32[] constant(1)
%constant.2 = f32[] constant(2)
%all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
channel_id=1, to_apply=%sum
%add.11 = f32[] add(%constant.1, %all-reduce.ar.1)
%add.12 = f32[] add(%constant.2, %add.11)
%all-reduce.1 = f32[] all-reduce(%add.12), replica_groups={{0,1}}, to_apply=%sum
ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::AllReduce(
op::Add(op::Divide(op::Constant(), op::Constant()),
op::Add(op::Divide(op::Constant(), op::Constant()),
op::Parameter())))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) {
const char* module_str = R"(
HloModule foobar
@@ -957,7 +1185,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(
@@ -973,6 +1202,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArSubtractCrsSPMD) {
const char* module_str = R"(
HloModule foobar
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[]) {
%p = f32[] parameter(0)
%constant.f32 = f32[] constant(123)
%all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
channel_id=1, to_apply=%sum.f32
%sub.1 = f32[] subtract(%constant.f32, %all-reduce.ar.1)
%all-reduce.1 = f32[] all-reduce(%sub.1), replica_groups={{0,1}},
to_apply=%sum.f32
ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(op::AllReduce(op::Subtract(
op::Divide(op::Constant(), op::Constant()), op::Parameter()))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) {
const char* module_str = R"(
HloModule foobar
@@ -1047,7 +1317,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -1065,6 +1336,53 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeftSPMD) {
const char* module_str = R"(
HloModule foobar
%sum (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[]) {
%p = f32[] parameter(0)
%const1 = f32[] constant(1)
%const2 = f32[] constant(2)
%ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1,
to_apply=%sum
%add11 = f32[] add(%ar11, %const1)
%ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2,
to_apply=%sum
%add12 = f32[] add(%add11, %ar12)
%crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}},
to_apply=%sum
ROOT %tuple = (f32[]) tuple(%crs1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(op::AllReduce(op::Add(
op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())),
op::Parameter()))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) {
const char* module_str = R"(
HloModule foobar
@@ -1139,7 +1457,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(
@@ -1159,6 +1478,51 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteMultipleARsRightSPMD) {
const char* module_str = R"(
HloModule foobar
%sum (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[]) {
%p = f32[] parameter(0)
%const1 = f32[] constant(1)
%const2 = f32[] constant(2)
%ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, to_apply=%sum
%ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, to_apply=%sum
%add11 = f32[] add(%ar12, %const1)
%add12 = f32[] add(%ar11, %add11)
%crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, to_apply=%sum
ROOT %tuple = (f32[]) tuple(%crs1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::AllReduce(op::Add(
op::Parameter(),
op::Add(op::Parameter(),
op::Divide(op::Constant(), op::Constant()))))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
const char* module_str = R"(
HloModule foobar
@@ -1217,7 +1581,45 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_FALSE(changed);
}
TEST_F(ArCrsCombinerTest, OneReplicaDontRewriteSPMD) {
const char* module_str = R"(
HloModule foobar
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
%a = bf16[] parameter(0)
%b = bf16[] parameter(1)
ROOT %add = bf16[] add(%a, %b)
}
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
%p = bf16[] parameter(0)
%constant.bf16 = bf16[] constant(1)
%all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0}},
channel_id=1, to_apply=%sum.bf16
%convert.1 = f32[] convert(%all-reduce.ar.1)
%all-reduce.1 = f32[] all-reduce(%convert.1),
replica_groups={{0}}, to_apply=%sum.f32
ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_FALSE(changed);
}
@@ -1291,7 +1693,36 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/false);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_FALSE(changed);
}
TEST_F(ArCrsCombinerTest, AllReduceWithReplicasSPMD) {
const char* module_str = R"(
HloModule foobar
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
%p = bf16[] parameter(0)
%all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
to_apply=%sum.f32
%all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
to_apply=%sum.f32
ROOT %tuple = (f32[]) tuple(%all-reduce.2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_FALSE(changed);
}

tensorflow/compiler/xla/service/hlo_replication_analysis.cc

@@ -35,13 +35,45 @@ namespace {
// knowledge in hlo_replication.
bool DetermineHloInstructionIsReplicated(
const HloInstruction* hlo, const ShapeIndex& index,
bool cross_partition_spmd,
const absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>>&
hlo_replication) {
// Returns true if all operands are known to be replicated.
const auto all_operands_replicated =
[&hlo_replication](const HloInstruction* inst) {
for (auto operand : inst->operands()) {
auto operand_it = hlo_replication.find(operand);
if (operand_it == hlo_replication.end() ||
!operand_it->second.element({})) {
return false;
}
}
return true;
};
if (hlo->IsCrossReplicaAllReduce()) {
if (cross_partition_spmd) {
// Cross-replica all-reduce returns same values across partitions as long
// as its operands are replicated.
return all_operands_replicated(hlo);
}
// Only an all-reduce across all cores is replicated, which means there
// is only one subgroup.
return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
}
if (hlo->IsCrossModuleAllReduce()) {
return cross_partition_spmd;
}
if (hlo->HasSideEffectNoRecurse()) {
return false;
}
if (hlo->opcode() == HloOpcode::kReplicaId) {
return false;
// ReplicaId returns the same value for all partitions in each replica.
return cross_partition_spmd;
}
if (hlo->opcode() == HloOpcode::kPartitionId) {
// PartitionId returns the same value for all replicas in each partition.
return !cross_partition_spmd;
}
auto it = hlo_replication.find(hlo);
if (hlo->opcode() == HloOpcode::kParameter) {
@@ -55,11 +87,6 @@ bool DetermineHloInstructionIsReplicated(
if (hlo->opcode() == HloOpcode::kConstant) {
return true;
}
if (hlo->opcode() == HloOpcode::kAllReduce) {
// Only all-reduce across all cores are replicated, which means there
// is only one subgroup.
return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
}
if (hlo->IsElementwise() || //
hlo->opcode() == HloOpcode::kConcatenate || //
@@ -80,14 +107,7 @@ bool DetermineHloInstructionIsReplicated(
hlo->opcode() == HloOpcode::kDynamicUpdateSlice || //
hlo->opcode() == HloOpcode::kReduceWindow || //
hlo->opcode() == HloOpcode::kCopy) {
for (auto operand : hlo->operands()) {
auto operand_it = hlo_replication.find(operand);
if (operand_it == hlo_replication.end() ||
!operand_it->second.element({})) {
return false;
}
}
return true;
return all_operands_replicated(hlo);
}
return false;
}
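
The dispatch above can be summarized as follows (informal; "operands" is short for "all operands are known to be replicated"):

  Instruction                  across replicas          across partitions (SPMD)
  cross-replica all-reduce     single subgroup only     operands
  cross-module all-reduce      false                    true
  replica-id                   false                    true
  partition-id                 true                     false
  side-effecting ops           false                    false
  constant                     true                     true
  elementwise, slice, etc.     operands                 operands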
@@ -235,8 +255,8 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
ShapeUtil::ForEachSubshape(
inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
*shape_tree.mutable_element(index) =
DetermineHloInstructionIsReplicated(inst, index,
hlo_replication_);
DetermineHloInstructionIsReplicated(
inst, index, cross_partition_spmd_, hlo_replication_);
return Status::OK();
});
changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
@@ -248,23 +268,39 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
void HloReplicationAnalysis::ComputeHloReplication() {
// Add entry parameters to the above sets according to user annotation.
// Replicated modules read from `parameter_replicated_at_leaf_buffers` whereas
// SPMD partitioned modules read from HloSharding attributes.
auto entry = module_->entry_computation();
for (int i = 0; i < entry->num_parameters(); ++i) {
auto param = entry->parameter_instruction(i);
ShapeTree<bool> shape_tree(param->shape(), false);
const auto& replication = param->parameter_replicated_at_leaf_buffers();
int leaf_index = 0;
ShapeUtil::ForEachSubshape(
param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
return Status::OK();
}
if (replication && replication->at(leaf_index)) {
*shape_tree.mutable_element(index) = true;
}
++leaf_index;
return Status::OK();
});
if (cross_partition_spmd_ && param->has_sharding()) {
auto sharding_tree =
param->sharding().AsShapeTree(param->shape()).ValueOrDie();
ShapeUtil::ForEachSubshape(
param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
return Status::OK();
}
*shape_tree.mutable_element(index) =
sharding_tree.element(index).IsReplicated();
return Status::OK();
});
} else if (!cross_partition_spmd_) {
const auto& replication = param->parameter_replicated_at_leaf_buffers();
int leaf_index = 0;
ShapeUtil::ForEachSubshape(
param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
return Status::OK();
}
if (replication && replication->at(leaf_index)) {
*shape_tree.mutable_element(index) = true;
}
++leaf_index;
return Status::OK();
});
}
hlo_replication_[param] = std::move(shape_tree);
}
ComputeHloReplicationOnComputation(entry,
@@ -281,17 +317,18 @@ bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
}
/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module) {
HloReplicationAnalysis::Run(const HloModule* module,
bool cross_partition_spmd) {
const absl::flat_hash_set<const HloInstruction*> empty;
return Run(module, &empty);
return Run(module, cross_partition_spmd, &empty);
}
/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module,
HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
const absl::flat_hash_set<const HloInstruction*>*
loops_known_with_same_iterations) {
auto analysis = absl::WrapUnique(
new HloReplicationAnalysis(module, loops_known_with_same_iterations));
auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
module, cross_partition_spmd, loops_known_with_same_iterations));
analysis->ComputeHloReplication();
return analysis;
}

tensorflow/compiler/xla/service/hlo_replication_analysis.h

@@ -25,32 +25,35 @@ limitations under the License.
namespace xla {
// An HLO pass that determines whether each instruction in the module outputs
// the same value across replicas. It propagates sources of replicated values to
// the same value across replicas or across partitions (depending on the value of
// `cross_partition_spmd`). It propagates sources of replicated values to
// the rest of the module, where sources include cross-replica-sum, annotated
// entry parameters, and constants.
class HloReplicationAnalysis {
public:
// Runs the analysis on module and returns the result or an error.
static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run(
const HloModule* module);
const HloModule* module, bool cross_partition_spmd);
// Same as above, but the caller can provide additional annotations: a set of
// while loops that are known to have the same iteration counts across
// replicas.
// replicas or partitions.
static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run(
const HloModule* module, const absl::flat_hash_set<const HloInstruction*>*
loops_known_with_same_iterations);
const HloModule* module, bool cross_partition_spmd,
const absl::flat_hash_set<const HloInstruction*>*
loops_known_with_same_iterations);
// Returns if the HLO instruction outputs the same value (i.e., replicated) at
// the given index across all replicas.
// the given index across all replicas or partitions.
bool HloInstructionIsReplicatedAt(const HloInstruction* inst,
const ShapeIndex& index) const;
private:
HloReplicationAnalysis(const HloModule* module,
HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd,
const absl::flat_hash_set<const HloInstruction*>*
loops_known_with_same_iterations)
: module_(module),
cross_partition_spmd_(cross_partition_spmd),
loops_known_with_same_iterations_(*loops_known_with_same_iterations) {}
// Computes hlo_replication_.
@@ -63,14 +66,25 @@ class HloReplicationAnalysis {
const HloModule* module_;
// If true, run this replication analysis for replicated values across
// partitions (not across replicas) on an SPMD partitioned module. This means
// that HloInstructionIsReplicatedAt() returns true if the value is identical
// across partitions for each replica. The module-level parameter and root
// instructions may have HloSharding attributes that indicate whether values
// are identical across partitions.
//
// If false, HloReplicationAnalysis runs across replicas.
bool cross_partition_spmd_;
// A set of while loops that are known to have the same iteration counts
// across replicas. This is provided by the caller as additional annotations.
// across replicas or partitions. This is provided by the caller as additional
// annotations.
const absl::flat_hash_set<const HloInstruction*>&
loops_known_with_same_iterations_;
// A map from each analyzed HLO instruction to a shape tree that represents
// whether the instruction outputs the same value across replicas at each
// shape index.
// whether the instruction outputs the same value across replicas or
// partitions at each shape index.
absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>> hlo_replication_;
};
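
A minimal usage sketch of the updated interface (assumes a StatusOr-propagating context; `module` and `inst` are pre-existing values):

  // Run the analysis across partitions of an SPMD module, then query it.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloReplicationAnalysis> analysis,
      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
  if (analysis->HloInstructionIsReplicatedAt(inst, /*index=*/{})) {
    // `inst` provably yields the same value on every partition.
  }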

tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc

@@ -42,16 +42,30 @@ sum {
ROOT add.2 = f32[] add(a, b)
}
sum.u32 {
a = u32[] parameter(0)
b = u32[] parameter(1)
ROOT add.2 = u32[] add(a, b)
}
ENTRY entry {
param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0)
get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0
get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1
after-all.1 = token[] after-all()
replica-id = u32[] replica-id()
partition-id = u32[] partition-id()
infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1)
get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0
dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}
all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, to_apply=sum
dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={},
to_apply=sum
subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3, all-reduce)
all-reduce-partitions = u32[] all-reduce(partition-id), channel_id=1,
to_apply=sum.u32
all-reduce-subgroup = u32[] all-reduce(partition-id),
replica_groups={{0,1},{2,3}}, to_apply=sum.u32
ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract)
}
)";
@@ -62,7 +76,8 @@ ENTRY entry {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{false, true});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "get-tuple-element.2"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -77,6 +92,92 @@ ENTRY entry {
FindInstruction(module.get(), "subtract"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "add"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "replica-id"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "partition-id"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "all-reduce-partitions"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "all-reduce-subgroup"), {}));
}
TEST_F(HloReplicationAnalysisTest, NoControlFlowSPMD) {
const string module_str = R"(
HloModule NoControlFlow
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
sum.u32 {
a = u32[] parameter(0)
b = u32[] parameter(1)
ROOT add.2 = u32[] add(a, b)
}
ENTRY entry {
param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0),
sharding={{maximal device=0}, {replicated}}
get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0
get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1
after-all.1 = token[] after-all()
replica-id = u32[] replica-id()
partition-id = u32[] partition-id()
infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1)
get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0
dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={},
to_apply=sum
all-reduce-subgroup = f32[4096,4096]{1,0} all-reduce(dot),
replica_groups={{0,1},{2,3}}, to_apply=sum
all-reduce-partitions = f32[4096,4096]{1,0} all-reduce(get-tuple-element.2),
channel_id=1, to_apply=sum
subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3,
all-reduce-partitions)
all-reduce-same-operand = u32[] all-reduce(replica-id), to_apply=sum.u32
all-reduce-same-operand-subgroup = u32[] all-reduce(replica-id),
replica_groups={{0,1},{2,3}}, to_apply=sum.u32
all-reduce-different-operand = u32[] all-reduce(partition-id),
to_apply=sum.u32
ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "get-tuple-element.2"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "get-tuple-element.3"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "get-tuple-element.5"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "dot"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "all-reduce"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "subtract"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "add"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "replica-id"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "partition-id"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "all-reduce-partitions"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "all-reduce-same-operand"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "all-reduce-same-operand-subgroup"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "all-reduce-different-operand"), {}));
}
TEST_F(HloReplicationAnalysisTest, NestedCall) {
@@ -111,7 +212,8 @@ ENTRY entry {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, false});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "get-tuple-element"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -163,7 +265,8 @@ ENTRY SimpleWhileLoop {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "tuple"), {0}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -212,7 +315,8 @@ ENTRY WhileLoopParameterAliasingNonReplicatedOutput {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "multiply"), {}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -258,7 +362,8 @@ ENTRY WhileLoopDifferentCondition {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "while"), {0}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -307,7 +412,8 @@ ENTRY entry {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true, true, true, false, true, true});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "tuple"), {0}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -371,7 +477,8 @@ ENTRY entry {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true, true, true, true, true});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "tuple"), {0}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -409,7 +516,8 @@ ENTRY entry {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, false, true, true, true});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "tuple-select"), {0}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -435,7 +543,8 @@ ENTRY entry {
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true, true, true, false});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(module.get()));
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "tuple-select"), {0}));
EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(