[TF:XLA] Don't do anything in the AR/CRS combiner if there is only one replica.

PiperOrigin-RevId: 230749050
Author:    Dimitris Vardoulakis
Date:      2019-01-24 10:45:23 -08:00
Committer: TensorFlower Gardener
Parent:    34c412876d
Commit:    b0fbe61949

4 changed files with 82 additions and 3 deletions

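For orientation before the per-file diffs: the new behavior is exercised as below. A minimal sketch, assuming the surrounding test fixture of ar_crs_combiner_test.cc; the constructor argument name is an assumption, and `module` stands for an HloModule whose all-reduces each use replica_groups={{0}}.

    // With every replica group containing a single replica id, the
    // AllReduces communicate with nobody, so after this commit the
    // combiner leaves the module untouched and reports no change.
    ArCrsCombiner combiner(/*num_spatial_partitions=*/2);
    bool changed = combiner.Run(module.get()).ValueOrDie();
    // changed == false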
tensorflow/compiler/xla/service/ar_crs_combiner.cc

@@ -85,9 +85,12 @@ absl::optional<HloInstruction*> MatchesArCrsPattern(
       return absl::nullopt;
     }
   }
-  return computation_is_addition(next->called_computations()[0])
-             ? absl::optional<HloInstruction*>(next)
-             : absl::nullopt;
+  if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
+      computation_is_addition(next->called_computations()[0])) {
+    return absl::optional<HloInstruction*>(next);
+  } else {
+    return absl::nullopt;
+  }
 }
 
 }  // namespace
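The guard added above is the substance of the change: a candidate AllReduce is now rejected outright when it performs no communication, where previously only the addition check was applied. A compressed, self-contained sketch of that decision; `Instr`, `is_noop`, and `is_addition` are hypothetical stand-ins for the HLO instruction type, Cast<HloAllReduceInstruction>(next)->IsNoop(), and computation_is_addition(...).

    #include <optional>

    struct Instr {
      bool is_noop;      // would be IsNoop() on the all-reduce
      bool is_addition;  // would be computation_is_addition(...)
    };

    // Keep the candidate only if it actually communicates and its
    // reduction computation is an addition.
    std::optional<Instr*> MatchTail(Instr* next) {
      if (!next->is_noop && next->is_addition) {
        return next;
      }
      return std::nullopt;
    }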

tensorflow/compiler/xla/service/ar_crs_combiner_test.cc

@@ -1108,5 +1108,68 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   CompareReplicaGroups(replica_groups_before, replica_groups_after);
 }
 
+TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
+  const char* module_str = R"(
+HloModule foobar
+
+%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
+  %a = bf16[] parameter(0)
+  %b = bf16[] parameter(1)
+  ROOT %add = bf16[] add(%a, %b)
+}
+
+%sum.f32 (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
+  %p = bf16[] parameter(0)
+  %constant.bf16 = bf16[] constant(1)
+
+  %all-reduce.ar.1 = bf16[]
+      all-reduce(%p),
+      replica_groups={{0}},
+      all_reduce_id=1,
+      to_apply=%sum.bf16,
+      sharding={maximal device=0}
+  %convert.1 = f32[]
+      convert(%all-reduce.ar.1),
+      sharding={maximal device=0}
+  %all-reduce.1 = f32[]
+      all-reduce(%convert.1),
+      replica_groups={{0}},
+      to_apply=%sum.f32,
+      sharding={maximal device=0}
+
+  %all-reduce.ar.2 = bf16[]
+      all-reduce(%constant.bf16),
+      replica_groups={{0}},
+      all_reduce_id=1,
+      to_apply=%sum.bf16,
+      sharding={maximal device=1}
+  %convert.2 = f32[]
+      convert(%all-reduce.ar.2),
+      sharding={maximal device=1}
+  %all-reduce.2 = f32[]
+      all-reduce(%convert.2),
+      replica_groups={{0}},
+      to_apply=%sum.f32,
+      sharding={maximal device=1}
+
+  ROOT %tuple = (f32[], f32[])
+      tuple(%all-reduce.1, %all-reduce.2),
+      sharding={{maximal device=0}, {maximal device=1}}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  ArCrsCombiner combiner(2);
+  auto changed = combiner.Run(module.get()).ValueOrDie();
+  EXPECT_FALSE(changed);
+}
+
 }  // namespace
 }  // namespace xla

tensorflow/compiler/xla/service/hlo_instructions.cc

@@ -383,6 +383,15 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const {
   return proto;
 }
 
+bool HloAllReduceInstruction::IsNoop() const {
+  for (auto replica_group : replica_groups()) {
+    if (replica_group.replica_ids().size() != 1) {
+      return false;
+    }
+  }
+  return !all_reduce_id();
+}
+
 std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
     const HloPrintOptions& options) const {
   std::vector<string> result =

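To make the predicate above concrete: it returns true exactly when every replica group is a singleton and the op carries no cross-module all_reduce_id. A standalone sketch with plain containers standing in for the proto types; the ReplicaGroup struct and the std::optional id are illustrative assumptions.

    #include <cstdint>
    #include <optional>
    #include <vector>

    struct ReplicaGroup { std::vector<int64_t> replica_ids; };

    bool IsNoop(const std::vector<ReplicaGroup>& groups,
                std::optional<int64_t> all_reduce_id) {
      for (const auto& g : groups) {
        // A group with two or more replicas implies real communication.
        if (g.replica_ids.size() != 1) return false;
      }
      // All groups are singletons; without a cross-module id the op
      // merely forwards its operand, i.e. it is a memcpy.
      return !all_reduce_id.has_value();
    }

For instance, IsNoop({{{0}}, {{1}}}, std::nullopt) is true, while IsNoop({{{0, 1}}}, std::nullopt) is false.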
tensorflow/compiler/xla/service/hlo_instructions.h

@@ -253,6 +253,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
   // Returns a serialized representation of this instruction.
   HloInstructionProto ToProto() const override;
 
+  // Returns true if the AllReduce does no communication, so it's equivalent
+  // to a mem copy.
+  bool IsNoop() const;
+
  private:
   std::vector<string> ExtraAttributesToStringImpl(
       const HloPrintOptions& options) const override;