[XLA] Replace all-reduces whose replica groups each contain a single replica with their operand.

PiperOrigin-RevId: 331155912
Change-Id: I26174aa164b5b1d6684691c4f7dd812ee5485a50
This commit is contained in:
Blake Hechtman 2020-09-11 08:50:20 -07:00 committed by TensorFlower Gardener
parent e76be1d662
commit acf37b5ae4
2 changed files with 60 additions and 26 deletions

View File

@ -31,27 +31,7 @@ StatusOr<bool> AllReduceSimplifier::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(
auto replication,
HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false));
std::vector<HloInstruction*> all_reduces_to_replace;
for (auto computation : module->computations()) {
for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
if (!inst->shape().IsArray()) {
// We currently do not change tuple-shaped all-reduce.
// Until XLA will support Token fed AllReduce(), the PyTorch client code
// uses a fake data token (constant) which relies on this pass to not
// optimize out (being fed within a tuple input).
continue;
}
if (inst->IsCrossReplicaAllReduce() &&
replication->HloInstructionIsReplicatedAt(inst->operand(0), {})) {
all_reduces_to_replace.push_back(inst);
}
}
}
bool changed = false;
if (all_reduces_to_replace.empty()) {
return changed;
}
std::vector<std::pair<HloInstruction*, int64>> all_reduces_to_replace;
// Returns the size of a replica group if all groups have the same size, or -1
// if they have different sizes.
@ -71,7 +51,40 @@ StatusOr<bool> AllReduceSimplifier::Run(HloModule* module) {
return replica_group_size;
};
for (auto all_reduce : all_reduces_to_replace) {
for (auto computation : module->computations()) {
for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
if (!inst->shape().IsArray()) {
// We currently do not change tuple-shaped all-reduce.
// Until XLA will support Token fed AllReduce(), the PyTorch client code
// uses a fake data token (constant) which relies on this pass to not
// optimize out (being fed within a tuple input).
continue;
}
if (!inst->IsCrossReplicaAllReduce()) {
continue;
}
int64 group_size = get_replica_group_size(inst);
if (group_size == -1) {
continue;
}
if (replication->HloInstructionIsReplicatedAt(inst->operand(0), {}) ||
group_size == 1) {
all_reduces_to_replace.push_back({inst, group_size});
}
}
}
bool changed = false;
for (auto all_reduce_and_group_size : all_reduces_to_replace) {
auto all_reduce = all_reduce_and_group_size.first;
const int64 replica_group_size = all_reduce_and_group_size.second;
if (replica_group_size == 1) {
TF_RETURN_IF_ERROR(all_reduce->parent()->ReplaceInstruction(
all_reduce, all_reduce->mutable_operand(0)));
changed = true;
continue;
}
if (all_reduce->to_apply()->instruction_count() != 3 ||
all_reduce->to_apply()->num_parameters() != 2) {
continue;
@ -79,10 +92,6 @@ StatusOr<bool> AllReduceSimplifier::Run(HloModule* module) {
HloInstruction* replacement;
switch (all_reduce->to_apply()->root_instruction()->opcode()) {
case HloOpcode::kAdd: {
int64 replica_group_size = get_replica_group_size(all_reduce);
if (replica_group_size == -1) {
continue;
}
// Create the multiplier:
// broadcast(convert_to_matching_type(s32 group size))
auto multiplier =

View File

@ -167,5 +167,30 @@ test {
m::Parameter(0), m::AllReduce(m::Parameter(1)))));
}
// Checks that an all-reduce whose replica groups are all singletons is
// replaced by its operand, even though that operand is explicitly marked as
// not replicated.
TEST_F(AllReduceSimplifierTest, TrivialSubgroupAllReduce) {
  // Eight replicas, each alone in its own replica group; p0 carries
  // parameter_replication={false}.
  const char* kHloText = R"(
HloModule m
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add.2 = f32[] add(a, b)
}
test {
p0 = f32[8,16] parameter(0), parameter_replication={false}
ROOT all-reduce = f32[8,16] all-reduce(p0),
replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}},
to_apply=sum
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnVerifiedModule(kHloText, /*replica_count=*/8));

  AllReduceSimplifier pass(/*replica_count=*/8);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  // The pass must report that it modified the module.
  EXPECT_TRUE(changed);

  // After simplification the entry root is expected to be the parameter
  // itself rather than the all-reduce.
  auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(0)));
}
} // namespace
} // namespace xla