[XLA] Replace all-reduces whose replica groups all have size one with their operand.
PiperOrigin-RevId: 331155912 Change-Id: I26174aa164b5b1d6684691c4f7dd812ee5485a50
This commit is contained in:
parent
e76be1d662
commit
acf37b5ae4
@ -31,27 +31,7 @@ StatusOr<bool> AllReduceSimplifier::Run(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto replication,
|
||||
HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false));
|
||||
std::vector<HloInstruction*> all_reduces_to_replace;
|
||||
for (auto computation : module->computations()) {
|
||||
for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
|
||||
if (!inst->shape().IsArray()) {
|
||||
// We currently do not change tuple-shaped all-reduce.
|
||||
// Until XLA will support Token fed AllReduce(), the PyTorch client code
|
||||
// uses a fake data token (constant) which relies on this pass to not
|
||||
// optimize out (being fed within a tuple input).
|
||||
continue;
|
||||
}
|
||||
if (inst->IsCrossReplicaAllReduce() &&
|
||||
replication->HloInstructionIsReplicatedAt(inst->operand(0), {})) {
|
||||
all_reduces_to_replace.push_back(inst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
if (all_reduces_to_replace.empty()) {
|
||||
return changed;
|
||||
}
|
||||
std::vector<std::pair<HloInstruction*, int64>> all_reduces_to_replace;
|
||||
|
||||
// Returns the size of a replica group if all groups have the same size, or -1
|
||||
// if they have different sizes.
|
||||
@ -71,7 +51,40 @@ StatusOr<bool> AllReduceSimplifier::Run(HloModule* module) {
|
||||
return replica_group_size;
|
||||
};
|
||||
|
||||
for (auto all_reduce : all_reduces_to_replace) {
|
||||
for (auto computation : module->computations()) {
|
||||
for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
|
||||
if (!inst->shape().IsArray()) {
|
||||
// We currently do not change tuple-shaped all-reduce.
|
||||
// Until XLA will support Token fed AllReduce(), the PyTorch client code
|
||||
// uses a fake data token (constant) which relies on this pass to not
|
||||
// optimize out (being fed within a tuple input).
|
||||
continue;
|
||||
}
|
||||
if (!inst->IsCrossReplicaAllReduce()) {
|
||||
continue;
|
||||
}
|
||||
int64 group_size = get_replica_group_size(inst);
|
||||
if (group_size == -1) {
|
||||
continue;
|
||||
}
|
||||
if (replication->HloInstructionIsReplicatedAt(inst->operand(0), {}) ||
|
||||
group_size == 1) {
|
||||
all_reduces_to_replace.push_back({inst, group_size});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
|
||||
for (auto all_reduce_and_group_size : all_reduces_to_replace) {
|
||||
auto all_reduce = all_reduce_and_group_size.first;
|
||||
const int64 replica_group_size = all_reduce_and_group_size.second;
|
||||
if (replica_group_size == 1) {
|
||||
TF_RETURN_IF_ERROR(all_reduce->parent()->ReplaceInstruction(
|
||||
all_reduce, all_reduce->mutable_operand(0)));
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
if (all_reduce->to_apply()->instruction_count() != 3 ||
|
||||
all_reduce->to_apply()->num_parameters() != 2) {
|
||||
continue;
|
||||
@ -79,10 +92,6 @@ StatusOr<bool> AllReduceSimplifier::Run(HloModule* module) {
|
||||
HloInstruction* replacement;
|
||||
switch (all_reduce->to_apply()->root_instruction()->opcode()) {
|
||||
case HloOpcode::kAdd: {
|
||||
int64 replica_group_size = get_replica_group_size(all_reduce);
|
||||
if (replica_group_size == -1) {
|
||||
continue;
|
||||
}
|
||||
// Create the multiplier:
|
||||
// broadcast(convert_to_matching_type(s32 group size))
|
||||
auto multiplier =
|
||||
|
@ -167,5 +167,30 @@ test {
|
||||
m::Parameter(0), m::AllReduce(m::Parameter(1)))));
|
||||
}
|
||||
|
||||
// Verifies that an all-reduce whose replica groups each contain exactly one
// replica ({0},{1},...,{7}) is replaced by its operand. The operand p0 is
// declared with parameter_replication={false}, so the rewrite exercised here
// is the singleton-group case rather than the replicated-operand case.
TEST_F(AllReduceSimplifierTest, TrivialSubgroupAllReduce) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
|
||||
sum {
|
||||
a = f32[] parameter(0)
|
||||
b = f32[] parameter(1)
|
||||
ROOT add.2 = f32[] add(a, b)
|
||||
}
|
||||
|
||||
|
||||
test {
|
||||
p0 = f32[8,16] parameter(0), parameter_replication={false}
|
||||
ROOT all-reduce = f32[8,16] all-reduce(p0),
|
||||
replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}},
|
||||
to_apply=sum
|
||||
}
|
||||
)";
|
||||
// Parse the module with 8 replicas, matching the eight singleton groups.
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||
kModuleStr, /*replica_count=*/8));
|
||||
AllReduceSimplifier simplifier(/*replica_count=*/8);
|
||||
// The pass must report that it changed the module...
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
// ...and the entry root must now be parameter 0 itself (all-reduce removed).
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Parameter(0)));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user