Combine cross-replica / cross-partition AllReduce after SPMD partition
PiperOrigin-RevId: 283610192
Change-Id: I801097d159c39d8137457c55906d455e0ee7733d
This commit is contained in:
parent
1e7a91e26a
commit
3c28370a9c
tensorflow/compiler/xla/service/BUILD

@@ -4197,6 +4197,7 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla/service:hlo_replication_analysis",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
     ],
tensorflow/compiler/xla/service/all_reduce_simplifier.cc

@@ -28,7 +28,9 @@ limitations under the License.
 namespace xla {

 StatusOr<bool> AllReduceSimplifier::Run(HloModule* module) {
-  TF_ASSIGN_OR_RETURN(auto replication, HloReplicationAnalysis::Run(module));
+  TF_ASSIGN_OR_RETURN(
+      auto replication,
+      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false));
   std::vector<HloInstruction*> all_reduces_to_replace;
   for (auto computation : module->computations()) {
     for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
tensorflow/compiler/xla/service/ar_crs_combiner.cc

@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -240,7 +241,8 @@ bool ArCrsCombiner::TupleElementsComputeSameValue(
 /* static */
 bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
                                                      HloInstruction* i2) {
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
+                         /*spmd_partition=*/false);
   auto module = i1->parent()->parent();
   CHECK_EQ(module, i2->parent()->parent());
   combiner.call_graph_ = CallGraph::Build(module);
@@ -363,14 +365,14 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
   }
 }

-void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
+Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() {
   for (auto it : all_reduce_map_) {
     auto channel_id = it.first;
     VLOG(2)
         << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
         << channel_id << "\n";
     auto pairs_vec = it.second;
-    CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
+    TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_);
     auto instr_0 = pairs_vec[0].ar;
     for (int i = 1; i < pairs_vec.size(); ++i) {
       auto instr_i = pairs_vec[i].ar;
@@ -393,6 +395,44 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
       }
     }
   }
+  return Status::OK();
 }
+
+Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD(
+    HloModule* module) {
+  // For SPMD mode, use HloReplicationAnalysis to figure out HLO value
+  // equivalence across partitions.
+  TF_ASSIGN_OR_RETURN(
+      auto replication_analysis,
+      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
+
+  for (auto it : all_reduce_map_) {
+    auto channel_id = it.first;
+    VLOG(2)
+        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
+        << channel_id << "\n";
+    auto pairs_vec = it.second;
+    TF_RET_CHECK(pairs_vec.size() == 1);
+    auto instr = pairs_vec[0].ar;
+    auto next = instr->users()[0];
+    while (true) {
+      // The patterns we detect in ArCrsCombiner::MatchesArCrsPattern()
+      // guarantee that the HLO produces an array.
+      TF_RET_CHECK(next->shape().IsArray());
+      if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) {
+        all_reduce_map_.erase(channel_id);
+        VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
+                   "channel id: "
+                << channel_id << "\n";
+        break;
+      }
+      if (next->IsCrossReplicaAllReduce()) {
+        break;
+      }
+      next = next->users()[0];
+    }
+  }
+  return Status::OK();
+}
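
To make the SPMD check above concrete, this is the kind of user chain it walks, written as hypothetical HLO in a C++ comment (the instruction names are made up, not taken from this change). Each hop must produce an array that HloReplicationAnalysis proves identical across partitions, and the walk stops at the first cross-replica all-reduce:

// %ar  = f32[] all-reduce(%p), channel_id=1, ...         // pairs_vec[0].ar (the CMAR)
// %mul = f32[] multiply(%ar, %c)                         // `next`: must be replicated at index {}
// %crs = f32[] all-reduce(%mul), replica_groups={{0,1}}  // IsCrossReplicaAllReduce() -> stop
//
// If %mul were not provably replicated across partitions, the group for
// channel_id=1 would be erased from all_reduce_map_ and left unrewritten.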

 StatusOr<bool> ArCrsCombiner::RewriteGraph() {
@@ -460,7 +500,11 @@ StatusOr<bool> ArCrsCombiner::Run(HloModule* module) {

   GroupAllReducesById(module);

-  KeepProvablyEqualInstructionGroups();
+  if (spmd_partition_) {
+    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module));
+  } else {
+    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD());
+  }

   return RewriteGraph();
 }

tensorflow/compiler/xla/service/ar_crs_combiner.h

@@ -25,18 +25,21 @@ limitations under the License.

 namespace xla {

-// When the HLO graph contains a cross-module AllReduce, followed by some simple
-// linear operations, followed by a cross-replica AllReduce (also known as
-// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an
-// efficient AllReduce implementation that fully utilizes the interconnect
-// bandwidth.
-// Such sequences appear in spatially partitioned models.
+// When the HLO graph contains a cross-module AllReduce (N separate AllReduce
+// ops that share the same channel_id for MPMD partitioning, or 1 AllReduce op
+// for SPMD partitioning), followed by some simple linear operations, followed
+// by a cross-replica AllReduce (also known as cross-replica sum, or CRS), we
+// can combine the CMAR and the CRAR, to use an efficient AllReduce
+// implementation that fully utilizes the interconnect bandwidth.
+//
+// Such sequences appear in spatially partitioned models (either MPMD or SPMD).
 // This pass must run right after spatial partitioning, when the code is still
 // in a single HLO module.
 //
 // The steps are:
 // 1) Find CMARs followed by simple ops followed by CRARs.
-// 2) Group CMARs by channel_id. They must all be rewritten.
+// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD
+//    partitioning, there will only be a single CMAR for each channel_id.
 // 3) Prove that the CMAR patterns in each core produce the same result.
 // 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
 //    other operand by the number of spatial partitions.
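
As a worked illustration of steps 1-4, consider a hypothetical two-partition module (a sketch only, not code from this change). The CMAR feeds an add whose other summand is provably the same on every core, so the pass deletes the CMAR and divides that summand by the partition count before the combined all-reduce:

// Before:                                 // After:
//   %ar  = f32[] all-reduce(%p),          //   %div = f32[] divide(%c, %two)
//       channel_id=1, ...                 //   %add = f32[] add(%div, %p)
//   %add = f32[] add(%c, %ar)             //   %crs = f32[] all-reduce(%add), ...
//   %crs = f32[] all-reduce(%add), ...    //   (one AllReduce instead of two)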
@@ -69,9 +72,11 @@ namespace xla {
 //
 class ArCrsCombiner : public HloModulePass {
  public:
-  ArCrsCombiner(int num_spatial_partitions, int num_replicas)
+  ArCrsCombiner(int num_spatial_partitions, int num_replicas,
+                bool spmd_partition)
       : num_spatial_partitions_(num_spatial_partitions),
-        num_replicas_(num_replicas) {}
+        num_replicas_(num_replicas),
+        spmd_partition_(spmd_partition) {}
   absl::string_view name() const override { return "ar-crs-combiner"; }
   StatusOr<bool> Run(HloModule* module) override;
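
A minimal construction sketch for the new three-argument form (assumed caller code; HloPassPipeline is the standard XLA pass driver, but this snippet is illustrative and not part of this change):

#include "tensorflow/compiler/xla/service/ar_crs_combiner.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

xla::StatusOr<bool> CombineArCrs(xla::HloModule* module, bool is_spmd) {
  xla::HloPassPipeline pipeline("ar-crs-combiner-pipeline");
  // In MPMD mode each channel_id has N all-reduce ops; in SPMD mode just one.
  pipeline.AddPass<xla::ArCrsCombiner>(/*num_spatial_partitions=*/2,
                                       /*num_replicas=*/2,
                                       /*spmd_partition=*/is_spmd);
  return pipeline.Run(module);
}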

@@ -153,7 +158,10 @@ class ArCrsCombiner : public HloModulePass {

   // Looks at each AllReduce group in all_reduce_map_, and keeps only the
   // groups for which it's safe to move the AllReduce later in the HLO graph.
-  void KeepProvablyEqualInstructionGroups();
+  Status KeepProvablyEqualInstructionGroupsMPMD();
+
+  // Same as above, but runs on SPMD partitioned module instead of MPMD.
+  Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module);

   // Performs the graph rewrite that eliminates the early AllReduce and turns
   // the later CRS into an AllReduce.
@@ -163,6 +171,15 @@ class ArCrsCombiner : public HloModulePass {

   int num_replicas_;

+  // Run this combiner pass assuming the input module is an SPMD partitioned
+  // module (as opposed to MPMD partitioned).
+  //
+  // The main difference between the two w.r.t. this pass is that there would be
+  // N all-reduce ops for each channel in MPMD mode, whereas there is only 1
+  // for each channel in SPMD mode. Also we use HloReplicationAnalysis for HLO
+  // equivalence check in SPMD mode.
+  bool spmd_partition_;
+
   // Map from all-reduce ids to the AR/CRS pairs.
   absl::flat_hash_map<int64, std::vector<ArCrsPair>> all_reduce_map_;

tensorflow/compiler/xla/service/ar_crs_combiner_test.cc

@@ -452,7 +452,8 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -464,6 +465,55 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }

+TEST_F(ArCrsCombinerTest, RewriteArConvertCrsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
+  %p = bf16[] parameter(0)
+  %all-reduce.ar.1 = bf16[]
+      all-reduce(%p),
+      replica_groups={{0},{1}},
+      channel_id=1,
+      to_apply=%sum.bf16
+  %convert.1 = f32[] convert(%all-reduce.ar.1)
+  %all-reduce.1 = f32[]
+      all-reduce(%convert.1),
+      replica_groups={{0,1}},
+      to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::AllReduce(op::Convert(op::Parameter()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}

 TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
   const char* module_str = R"(
 HloModule foobar
@@ -520,7 +570,8 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -587,7 +638,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(
@@ -600,6 +652,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }

+TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.f32 = f32[] constant(123)
+
+  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum.f32
+  %multiply.1 = f32[] multiply(%all-reduce.ar.1, %constant.f32)
+  %all-reduce.1 = f32[] all-reduce(%multiply.1), replica_groups={{0,1}},
+      to_apply=%sum.f32, sharding={maximal device=0}
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}

 TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
   const char* module_str = R"(
 HloModule foobar
@@ -668,7 +761,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(
@@ -684,6 +778,55 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }

+TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.bf16 = bf16[] constant(1)
+  %constant.f32 = f32[] constant(2)
+
+  %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum.bf16
+  %convert.1 = f32[] convert(%all-reduce.ar.1), sharding={maximal device=0}
+  %add.1 = f32[] add(%constant.f32, %convert.1)
+  %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}},
+      to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::AllReduce(op::Add(
+                  op::Divide(op::Constant(), op::Constant()), op::Convert()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}

 TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
   const char* module_str = R"(
 HloModule foobar
@@ -750,7 +893,46 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {

   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(module_str));
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_FALSE(changed);
 }

+TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewriteSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.bf16 = bf16[] constant(1)
+  %constant.f32.1 = f32[] constant(2)
+
+  %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum.bf16
+  %convert.1 = f32[] convert(%all-reduce.ar.1)
+  %add.1 = f32[] add(%p, %convert.1)
+  %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_FALSE(changed);
+}
@@ -810,7 +992,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -884,7 +1067,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -902,6 +1086,50 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }

+TEST_F(ArCrsCombinerTest, RewriteMultipleAddsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.1 = f32[] constant(1)
+  %constant.2 = f32[] constant(2)
+
+  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum
+  %add.11 = f32[] add(%constant.1, %all-reduce.ar.1)
+  %add.12 = f32[] add(%constant.2, %add.11)
+  %all-reduce.1 = f32[] all-reduce(%add.12), replica_groups={{0,1}}, to_apply=%sum
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::AllReduce(
+                  op::Add(op::Divide(op::Constant(), op::Constant()),
+                          op::Add(op::Divide(op::Constant(), op::Constant()),
+                                  op::Parameter())))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}

 TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) {
   const char* module_str = R"(
 HloModule foobar
@@ -957,7 +1185,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(
@@ -973,6 +1202,47 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }

+TEST_F(ArCrsCombinerTest, RewriteArSubtractCrsSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %constant.f32 = f32[] constant(123)
+  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
+      channel_id=1, to_apply=%sum.f32
+  %sub.1 = f32[] subtract(%constant.f32, %all-reduce.ar.1)
+  %all-reduce.1 = f32[] all-reduce(%sub.1), replica_groups={{0,1}},
+      to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::AllReduce(op::Subtract(
+          op::Divide(op::Constant(), op::Constant()), op::Parameter()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}

 TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) {
   const char* module_str = R"(
 HloModule foobar
@@ -1047,7 +1317,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -1065,6 +1336,53 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }

+TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeftSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %const1 = f32[] constant(1)
+  %const2 = f32[] constant(2)
+
+  %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1,
+      to_apply=%sum
+  %add11 = f32[] add(%ar11, %const1)
+  %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2,
+      to_apply=%sum
+  %add12 = f32[] add(%add11, %ar12)
+  %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}},
+      to_apply=%sum
+  ROOT %tuple = (f32[]) tuple(%crs1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::AllReduce(op::Add(
+          op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())),
+          op::Parameter()))));
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}

 TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) {
   const char* module_str = R"(
 HloModule foobar
@@ -1139,7 +1457,8 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   auto crs_before =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_before = crs_before->replica_groups();
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
   EXPECT_THAT(
@@ -1159,6 +1478,51 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }

+TEST_F(ArCrsCombinerTest, RewriteMultipleARsRightSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: f32[]) -> (f32[]) {
+  %p = f32[] parameter(0)
+  %const1 = f32[] constant(1)
+  %const2 = f32[] constant(2)
+
+  %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, to_apply=%sum
+  %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, to_apply=%sum
+  %add11 = f32[] add(%ar12, %const1)
+  %add12 = f32[] add(%ar11, %add11)
+  %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, to_apply=%sum
+  ROOT %tuple = (f32[]) tuple(%crs1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto crs_before =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_before = crs_before->replica_groups();
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::AllReduce(op::Add(
+                  op::Parameter(),
+                  op::Add(op::Parameter(),
+                          op::Divide(op::Constant(), op::Constant()))))));
+
+  auto crs_after =
+      module->entry_computation()->root_instruction()->operands()[0];
+  auto replica_groups_after = crs_after->replica_groups();
+  CompareReplicaGroups(replica_groups_before, replica_groups_after);
+}

 TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
   const char* module_str = R"(
 HloModule foobar
@@ -1217,7 +1581,45 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {

   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(module_str));
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_FALSE(changed);
 }

+TEST_F(ArCrsCombinerTest, OneReplicaDontRewriteSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
+  %p = bf16[] parameter(0)
+  %constant.bf16 = bf16[] constant(1)
+
+  %all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0}},
+      channel_id=1, to_apply=%sum.bf16
+  %convert.1 = f32[] convert(%all-reduce.ar.1)
+  %all-reduce.1 = f32[] all-reduce(%convert.1),
+      replica_groups={{0}}, to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_FALSE(changed);
+}
@@ -1291,7 +1693,36 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {

   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(module_str));
-  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2);
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/false);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_FALSE(changed);
 }

+TEST_F(ArCrsCombinerTest, AllReduceWithReplicasSPMD) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
+  %p = bf16[] parameter(0)
+  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
+      to_apply=%sum.f32
+  %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
+      to_apply=%sum.f32
+  ROOT %tuple = (f32[]) tuple(%all-reduce.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2,
+                         /*spmd_partition=*/true);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_FALSE(changed);
+}

tensorflow/compiler/xla/service/hlo_replication_analysis.cc

@@ -35,13 +35,45 @@ namespace {
 // knowledge in hlo_replication.
 bool DetermineHloInstructionIsReplicated(
     const HloInstruction* hlo, const ShapeIndex& index,
+    bool cross_partition_spmd,
     const absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>>&
         hlo_replication) {
+  // Returns true if all operands are known to be replicated.
+  const auto all_operands_replicated =
+      [&hlo_replication](const HloInstruction* inst) {
+        for (auto operand : inst->operands()) {
+          auto operand_it = hlo_replication.find(operand);
+          if (operand_it == hlo_replication.end() ||
+              !operand_it->second.element({})) {
+            return false;
+          }
+        }
+        return true;
+      };
+
+  if (hlo->IsCrossReplicaAllReduce()) {
+    if (cross_partition_spmd) {
+      // Cross-replica all-reduce returns same values across partitions as long
+      // as its operands are replicated.
+      return all_operands_replicated(hlo);
+    }
+    // Only all-reduce across all cores are replicated, which means there
+    // is only one subgroup.
+    return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
+  }
+  if (hlo->IsCrossModuleAllReduce()) {
+    return cross_partition_spmd;
+  }
   if (hlo->HasSideEffectNoRecurse()) {
     return false;
   }
   if (hlo->opcode() == HloOpcode::kReplicaId) {
-    return false;
+    // ReplicaId returns the same value for all partitions in each replica.
+    return cross_partition_spmd;
+  }
+  if (hlo->opcode() == HloOpcode::kPartitionId) {
+    // PartitionId returns the same value for all replicas in each partition.
+    return !cross_partition_spmd;
   }
   auto it = hlo_replication.find(hlo);
   if (hlo->opcode() == HloOpcode::kParameter) {
@@ -55,11 +87,6 @@ bool DetermineHloInstructionIsReplicated(
   if (hlo->opcode() == HloOpcode::kConstant) {
     return true;
   }
-  if (hlo->opcode() == HloOpcode::kAllReduce) {
-    // Only all-reduce across all cores are replicated, which means there
-    // is only one subgroup.
-    return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
-  }

   if (hlo->IsElementwise() ||                             //
       hlo->opcode() == HloOpcode::kConcatenate ||         //
@@ -80,14 +107,7 @@ bool DetermineHloInstructionIsReplicated(
       hlo->opcode() == HloOpcode::kDynamicUpdateSlice ||  //
       hlo->opcode() == HloOpcode::kReduceWindow ||        //
       hlo->opcode() == HloOpcode::kCopy) {
-    for (auto operand : hlo->operands()) {
-      auto operand_it = hlo_replication.find(operand);
-      if (operand_it == hlo_replication.end() ||
-          !operand_it->second.element({})) {
-        return false;
-      }
-    }
-    return true;
+    return all_operands_replicated(hlo);
   }
   return false;
 }
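
Summarizing the special cases above in one place (a paraphrase of the code in this function, not an authoritative table):

// Instruction                 cross_partition_spmd=true      cross_partition_spmd=false
// cross-replica all-reduce    all_operands_replicated(hlo)   true only if one subgroup
// cross-module all-reduce     true                           false
// kReplicaId                  true (same within a replica)   false
// kPartitionId                false                          true (same within a partition)
// side-effecting ops          false                          false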
@@ -235,8 +255,8 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
       ShapeUtil::ForEachSubshape(
           inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
             *shape_tree.mutable_element(index) =
-                DetermineHloInstructionIsReplicated(inst, index,
-                                                    hlo_replication_);
+                DetermineHloInstructionIsReplicated(
+                    inst, index, cross_partition_spmd_, hlo_replication_);
             return Status::OK();
           });
       changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
@@ -248,10 +268,25 @@ bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(

 void HloReplicationAnalysis::ComputeHloReplication() {
   // Add entry parameters to the above sets according to user annotation.
+  // Replicated modules read from `parameter_replicated_at_leaf_buffers` whereas
+  // SPMD partitioned modules read from HloSharding attributes.
   auto entry = module_->entry_computation();
   for (int i = 0; i < entry->num_parameters(); ++i) {
     auto param = entry->parameter_instruction(i);
     ShapeTree<bool> shape_tree(param->shape(), false);
+    if (cross_partition_spmd_ && param->has_sharding()) {
+      auto sharding_tree =
+          param->sharding().AsShapeTree(param->shape()).ValueOrDie();
+      ShapeUtil::ForEachSubshape(
+          param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+            if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
+              return Status::OK();
+            }
+            *shape_tree.mutable_element(index) =
+                sharding_tree.element(index).IsReplicated();
+            return Status::OK();
+          });
+    } else if (!cross_partition_spmd_) {
       const auto& replication = param->parameter_replicated_at_leaf_buffers();
       int leaf_index = 0;
       ShapeUtil::ForEachSubshape(
@@ -265,6 +300,7 @@ void HloReplicationAnalysis::ComputeHloReplication() {
             ++leaf_index;
             return Status::OK();
           });
+    }
     hlo_replication_[param] = std::move(shape_tree);
   }
   ComputeHloReplicationOnComputation(entry,
@@ -281,17 +317,18 @@ bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
 }

 /* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
-HloReplicationAnalysis::Run(const HloModule* module) {
+HloReplicationAnalysis::Run(const HloModule* module,
+                            bool cross_partition_spmd) {
   const absl::flat_hash_set<const HloInstruction*> empty;
-  return Run(module, &empty);
+  return Run(module, cross_partition_spmd, &empty);
 }

 /* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
-HloReplicationAnalysis::Run(const HloModule* module,
+HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
                             const absl::flat_hash_set<const HloInstruction*>*
                                 loops_known_with_same_iterations) {
-  auto analysis = absl::WrapUnique(
-      new HloReplicationAnalysis(module, loops_known_with_same_iterations));
+  auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
+      module, cross_partition_spmd, loops_known_with_same_iterations));
   analysis->ComputeHloReplication();
   return analysis;
 }

tensorflow/compiler/xla/service/hlo_replication_analysis.h

@@ -25,32 +25,35 @@ limitations under the License.
 namespace xla {

 // An HLO pass that determines whether each instruction in the module outputs
-// the same value across replicas. It propagates sources of replicated values to
+// the same value across replicas or across partitions (depending on the value
+// `cross_partition_spmd`). It propagates sources of replicated values to
 // the rest of the module, where sources include cross-replica-sum, annotated
 // entry parameters, and constants.
 class HloReplicationAnalysis {
  public:
   // Runs the analysis on module and returns the result or an error.
   static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run(
-      const HloModule* module);
+      const HloModule* module, bool cross_partition_spmd);

   // Same as above, but the caller can provide additional annotations: a set of
   // while loops that are known to have the same iteration counts across
-  // replicas.
+  // replicas or partitions.
   static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run(
-      const HloModule* module, const absl::flat_hash_set<const HloInstruction*>*
-                                   loops_known_with_same_iterations);
+      const HloModule* module, bool cross_partition_spmd,
+      const absl::flat_hash_set<const HloInstruction*>*
+          loops_known_with_same_iterations);

   // Returns if the HLO instruction outputs the same value (i.e., replicated) at
-  // the given index across all replicas.
+  // the given index across all replicas or partitions.
   bool HloInstructionIsReplicatedAt(const HloInstruction* inst,
                                     const ShapeIndex& index) const;

  private:
-  HloReplicationAnalysis(const HloModule* module,
+  HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd,
                          const absl::flat_hash_set<const HloInstruction*>*
                              loops_known_with_same_iterations)
       : module_(module),
+        cross_partition_spmd_(cross_partition_spmd),
         loops_known_with_same_iterations_(*loops_known_with_same_iterations) {}

   // Computes hlo_replication_.
@@ -63,14 +66,25 @@ class HloReplicationAnalysis {

   const HloModule* module_;

+  // If true, run this replication analysis for replicated values across
+  // partitions (not across replicas) on an SPMD partitioned module. This means
+  // that HloInstructionIsReplicatedAt() returns true if the value is identical
+  // across partitions for each replica. The module-level parameter and root
+  // instructions may have HloSharding attributes that indicate whether values
+  // are identical across partitions.
+  //
+  // If false, HloReplicationAnalysis runs across replicas.
+  bool cross_partition_spmd_;
+
   // A set of while loops that are known to have the same iteration counts
-  // across replicas. This is provided by the caller as additional annotations.
+  // across replicas or partitions. This is provided by the caller as additional
+  // annotations.
   const absl::flat_hash_set<const HloInstruction*>&
       loops_known_with_same_iterations_;

   // A map from each analyzed HLO instruction to a shape tree that represents
-  // whether the instruction outputs the same value across replicas at each
-  // shape index.
+  // whether the instruction outputs the same value across replicas or
+  // partitions at each shape index.
   absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>> hlo_replication_;
 };
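
A minimal caller sketch for the updated interface (assumed usage, mirroring the tests below):

TF_ASSIGN_OR_RETURN(
    std::unique_ptr<HloReplicationAnalysis> analysis,
    HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
if (analysis->HloInstructionIsReplicatedAt(instr, /*index=*/{})) {
  // `instr` provably produces the same value on every partition of each
  // replica; with cross_partition_spmd=false the same call would instead
  // assert value equality across replicas.
}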

tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc

@@ -42,16 +42,30 @@ sum {
   ROOT add.2 = f32[] add(a, b)
 }

+sum.u32 {
+  a = u32[] parameter(0)
+  b = u32[] parameter(1)
+  ROOT add.2 = u32[] add(a, b)
+}
+
 ENTRY entry {
   param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0)
   get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0
   get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1
   after-all.1 = token[] after-all()
+  replica-id = u32[] replica-id()
+  partition-id = u32[] partition-id()
   infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1)
   get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0
-  dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, to_apply=sum
+  dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={},
+    to_apply=sum
   subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3, all-reduce)
+  all-reduce-partitions = u32[] all-reduce(partition-id), channel_id=1,
+    to_apply=sum.u32
+  all-reduce-subgroup = u32[] all-reduce(partition-id),
+    replica_groups={{0,1},{2,3}}, to_apply=sum.u32
   ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract)
 }
 )";
@@ -62,7 +76,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{false, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "get-tuple-element.2"), {}));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -77,6 +92,92 @@ ENTRY entry {
       FindInstruction(module.get(), "subtract"), {}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "add"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "replica-id"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "partition-id"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-partitions"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-subgroup"), {}));
 }

+TEST_F(HloReplicationAnalysisTest, NoControlFlowSPMD) {
+  const string module_str = R"(
+HloModule NoControlFlow
+
+sum {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT add.2 = f32[] add(a, b)
+}
+
+sum.u32 {
+  a = u32[] parameter(0)
+  b = u32[] parameter(1)
+  ROOT add.2 = u32[] add(a, b)
+}
+
+ENTRY entry {
+  param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0),
+    sharding={{maximal device=0}, {replicated}}
+  get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0
+  get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1
+  after-all.1 = token[] after-all()
+  replica-id = u32[] replica-id()
+  partition-id = u32[] partition-id()
+  infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1)
+  get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0
+  dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={},
+    to_apply=sum
+  all-reduce-subgroup = f32[4096,4096]{1,0} all-reduce(dot),
+    replica_groups={{0,1},{2,3}}, to_apply=sum
+  all-reduce-partitions = f32[4096,4096]{1,0} all-reduce(get-tuple-element.2),
+    channel_id=1, to_apply=sum
+  subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3,
+    all-reduce-partitions)
+  all-reduce-same-operand = u32[] all-reduce(replica-id), to_apply=sum.u32
+  all-reduce-same-operand-subgroup = u32[] all-reduce(replica-id),
+    replica_groups={{0,1},{2,3}}, to_apply=sum.u32
+  all-reduce-different-operand = u32[] all-reduce(partition-id),
+    to_apply=sum.u32
+  ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloReplicationAnalysis> analysis,
+      HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "get-tuple-element.2"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "get-tuple-element.3"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "get-tuple-element.5"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "dot"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "subtract"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "add"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "replica-id"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "partition-id"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-partitions"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-same-operand"), {}));
+  EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-same-operand-subgroup"), {}));
+  EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "all-reduce-different-operand"), {}));
+}

 TEST_F(HloReplicationAnalysisTest, NestedCall) {
@@ -111,7 +212,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{true, false});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "get-tuple-element"), {}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -163,7 +265,8 @@ ENTRY SimpleWhileLoop {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple"), {0}));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -212,7 +315,8 @@ ENTRY WhileLoopParameterAliasingNonReplicatedOutput {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "multiply"), {}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -258,7 +362,8 @@ ENTRY WhileLoopDifferentCondition {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "while"), {0}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -307,7 +412,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{true, true, true, true, false, true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple"), {0}));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
@@ -371,7 +477,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{true, true, true, true, true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple"), {0}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -409,7 +516,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{true, false, true, true, true});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple-select"), {0}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
@@ -435,7 +543,8 @@ ENTRY entry {
   param->set_parameter_replicated_at_leaf_buffers(
       absl::Span<const bool>{true, true, true, true, false});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
-                          HloReplicationAnalysis::Run(module.get()));
+                          HloReplicationAnalysis::Run(
+                              module.get(), /*cross_partition_spmd=*/false));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(
       FindInstruction(module.get(), "tuple-select"), {0}));
   EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt(