From b0fbe619491151831e4ec3b2a30cc2770f315d9e Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Thu, 24 Jan 2019 10:45:23 -0800 Subject: [PATCH] [TF:XLA] Don't do anything in the AR/CRS combiner if there is only one replica. PiperOrigin-RevId: 230749050 --- .../compiler/xla/service/ar_crs_combiner.cc | 9 ++- .../xla/service/ar_crs_combiner_test.cc | 63 +++++++++++++++++++ .../compiler/xla/service/hlo_instructions.cc | 9 +++ .../compiler/xla/service/hlo_instructions.h | 4 ++ 4 files changed, 82 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index d01393a748e..f8dff6a700c 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -85,9 +85,12 @@ absl::optional MatchesArCrsPattern( return absl::nullopt; } } - return computation_is_addition(next->called_computations()[0]) - ? absl::optional(next) - : absl::nullopt; + if (!Cast(next)->IsNoop() && + computation_is_addition(next->called_computations()[0])) { + return absl::optional(next); + } else { + return absl::nullopt; + } } } // namespace diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index ea28bba959a..5152f0dc884 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -1108,5 +1108,68 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { CompareReplicaGroups(replica_groups_before, replica_groups_after); } +TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%convert.2), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 3a0d71dd88b..b01f01ef012 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -383,6 +383,15 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { return proto; } +bool HloAllReduceInstruction::IsNoop() const { + for (auto replica_group : replica_groups()) { + if (replica_group.replica_ids().size() != 1) { + return false; + } + } + return !all_reduce_id(); +} + std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector result = diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e6111cfb575..1b4a94753cd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -253,6 +253,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns true if the AllReduce does no communication, so it's equivalent + // to a mem copy. + bool IsNoop() const; + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override;