[TF:XLA] Don't do anything in the AR/CRS combiner if there is only one replica.
PiperOrigin-RevId: 230749050
This commit is contained in:
parent
34c412876d
commit
b0fbe61949
@ -85,9 +85,12 @@ absl::optional<HloInstruction*> MatchesArCrsPattern(
|
|||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return computation_is_addition(next->called_computations()[0])
|
if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
|
||||||
? absl::optional<HloInstruction*>(next)
|
computation_is_addition(next->called_computations()[0])) {
|
||||||
: absl::nullopt;
|
return absl::optional<HloInstruction*>(next);
|
||||||
|
} else {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -1108,5 +1108,68 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
|||||||
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
CompareReplicaGroups(replica_groups_before, replica_groups_after);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that ArCrsCombiner::Run reports no change for a module in which
// every all-reduce uses replica_groups={{0}}, i.e. one replica id per group.
TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
|
||||||
|
const char* module_str = R"(
|
||||||
|
HloModule foobar
|
||||||
|
|
||||||
|
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
|
||||||
|
%a = bf16[] parameter(0)
|
||||||
|
%b = bf16[] parameter(1)
|
||||||
|
ROOT %add = bf16[] add(%a, %b)
|
||||||
|
}
|
||||||
|
|
||||||
|
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
|
||||||
|
%x = f32[] parameter(0)
|
||||||
|
%y = f32[] parameter(1)
|
||||||
|
ROOT %add = f32[] add(%x, %y)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||||
|
%p = bf16[] parameter(0)
|
||||||
|
%constant.bf16 = bf16[] constant(1)
|
||||||
|
|
||||||
|
%all-reduce.ar.1 = bf16[]
|
||||||
|
all-reduce(%p),
|
||||||
|
replica_groups={{0}},
|
||||||
|
all_reduce_id=1,
|
||||||
|
to_apply=%sum.bf16,
|
||||||
|
sharding={maximal device=0}
|
||||||
|
%convert.1 = f32[]
|
||||||
|
convert(%all-reduce.ar.1),
|
||||||
|
sharding={maximal device=0}
|
||||||
|
%all-reduce.1 = f32[]
|
||||||
|
all-reduce(%convert.1),
|
||||||
|
replica_groups={{0}},
|
||||||
|
to_apply=%sum.f32,
|
||||||
|
sharding={maximal device=0}
|
||||||
|
|
||||||
|
%all-reduce.ar.2 = bf16[]
|
||||||
|
all-reduce(%constant.bf16),
|
||||||
|
replica_groups={{0}},
|
||||||
|
all_reduce_id=1,
|
||||||
|
to_apply=%sum.bf16,
|
||||||
|
sharding={maximal device=1}
|
||||||
|
%convert.2 = f32[]
|
||||||
|
convert(%all-reduce.ar.2),
|
||||||
|
sharding={maximal device=1}
|
||||||
|
%all-reduce.2 = f32[]
|
||||||
|
all-reduce(%convert.2),
|
||||||
|
replica_groups={{0}},
|
||||||
|
to_apply=%sum.f32,
|
||||||
|
sharding={maximal device=1}
|
||||||
|
|
||||||
|
ROOT %tuple = (f32[], f32[])
|
||||||
|
tuple(%all-reduce.1, %all-reduce.2),
|
||||||
|
sharding={{maximal device=0}, {maximal device=1}}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
// Parse the HLO text above; the macro fails the test if parsing does not
// succeed.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_str));
|
||||||
|
// NOTE(review): the argument 2 presumably corresponds to the two devices
// named in the sharding annotations of the module — confirm against the
// ArCrsCombiner constructor.
ArCrsCombiner combiner(2);
|
||||||
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
|
// The pass must report that it left the module untouched.
EXPECT_FALSE(changed);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -383,6 +383,15 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const {
|
|||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloAllReduceInstruction::IsNoop() const {
|
||||||
|
for (auto replica_group : replica_groups()) {
|
||||||
|
if (replica_group.replica_ids().size() != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return !all_reduce_id();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
|
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
|
||||||
const HloPrintOptions& options) const {
|
const HloPrintOptions& options) const {
|
||||||
std::vector<string> result =
|
std::vector<string> result =
|
||||||
|
@ -253,6 +253,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
|
|||||||
// Returns a serialized representation of this instruction.
|
// Returns a serialized representation of this instruction.
|
||||||
HloInstructionProto ToProto() const override;
|
HloInstructionProto ToProto() const override;
|
||||||
|
|
||||||
|
// Returns true if the AllReduce does no communication, so it's equivalent
|
||||||
|
// to a mem copy. The implementation reports this when every replica group
// contains exactly one replica id and `!all_reduce_id()` holds.
|
||||||
|
bool IsNoop() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<string> ExtraAttributesToStringImpl(
|
std::vector<string> ExtraAttributesToStringImpl(
|
||||||
const HloPrintOptions& options) const override;
|
const HloPrintOptions& options) const override;
|
||||||
|
Loading…
Reference in New Issue
Block a user