diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index 9460cc55e10..6d3045d8fe4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -95,7 +95,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
   // Cost condition: not fuse (simple, expensive producers) and (consumers who
   // reuse operand elements).
   if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) &&
-      consumer->ReusesOperandElements(operand_index)) {
+      ReusesOperandElements(consumer, operand_index)) {
     VLOG(2) << "Fusion is not profitable.";
     return false;
   }
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 4680f072140..76d7f641cb2 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -65,9 +65,8 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer,
   }
   // Cost condition: not fuse (simple, expensive producers) and (consumers who
   // reuse operand elements).
-  if (producer->opcode() != HloOpcode::kFusion &&
-      consumer->ReusesOperandElements(operand_index) &&
-      is_expensive(*producer)) {
+  if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) &&
+      ReusesOperandElements(consumer, operand_index)) {
     VLOG(4) << "Do not fuse simple, expensive producer " << producer->name()
             << " and consumer which reuses operand elements.";
     return false;
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 2085b1ea4d0..7f21a4dcd1f 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -711,4 +711,23 @@ HloInstruction::FusionKind InstructionFusion::ChooseKind(
   return HloInstruction::FusionKind::kLoop;
 }
 
+bool InstructionFusion::ReusesOperandElements(const HloInstruction* consumer,
+                                              int64 operand_index) {
+  auto operand = consumer->operand(operand_index);
+  auto it = reused_fusion_operands_.find(consumer);
+  if (it != reused_fusion_operands_.end() && it->second.contains(operand)) {
+    return true;
+  }
+  bool reuses = consumer->ReusesOperandElements(operand_index);
+  // If a parameter was reused, we can cache this information. Fusion
+  // computations only ever grow, so it becomes more likely that a parameter is
+  // reused, but a reused parameter will never become *not* reused.
+  if (reuses) {
+    // We cache the operand corresponding to the fusion parameter, because the
+    // parameter pointers would be invalidated after the next fusion.
+    reused_fusion_operands_[consumer].insert(operand);
+  }
+  return reuses;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 90d9da48e33..d51bf700371 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -1,4 +1,3 @@
-#include "absl/container/flat_hash_map.h"
 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +19,8 @@ limitations under the License.
 #include <functional>
 #include <utility>
 
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/fusion_queue.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -138,6 +139,11 @@ class InstructionFusion : public HloModulePass {
     return config_collection_mode_;
   }
 
+  // Returns whether 'consumer' may reuse elements of its `operand_index`th
+  // operand.
+  bool ReusesOperandElements(const HloInstruction* consumer,
+                             int64 operand_index);
+
  private:
   // The set of producers whose consumers we cannot fuse into.
   using HloInstructionSet = std::unordered_set<HloInstruction*>;
@@ -172,6 +178,11 @@ class InstructionFusion : public HloModulePass {
   // Configuration mode.
   FusionConfigCollection config_collection_mode_;
 
+  // Caches which operands are reused inside fusion computations.
+  absl::flat_hash_map<const HloInstruction*,
+                      absl::flat_hash_set<const HloInstruction*>>
+      reused_fusion_operands_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion);
 };
 