diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b2ec656a2ba..b3f8a9590f3 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -913,6 +913,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 4ee78b0874c..fa73ac261f8 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/types.h" @@ -148,6 +149,16 @@ std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates( << " would be too large of a fusion."; continue; } + // Make sure the emitter can codegen the fusion op efficiently. We currently + // can have exponential time/memory requirements for emitting certain fusion + // ops, in which case we don't want to fuse. + // TODO(b/119692968): Remove this once fixed in the emitter. + if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) { + VLOG(3) << "Fusion of " << producer->name() << " into " + << consumer->name() + << " would result in overly large code duplication."; + continue; + } fusion_candidates.push_back(consumer); } return fusion_candidates; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 283eb30d15f..6cb66290a9a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -909,5 +909,165 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) { EXPECT_EQ(2, CountMultiOutputFusions(module.get())); } +TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule module + +and.reduce_sub_computation { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT and = pred[] and(x, y) +} + +fused_computation.1 { + param_4.658 = f32[2,20,256]{2,0,1} parameter(4) + slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]} + constant.6847 = s32[] constant(0) + broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={} + param_9.415 = s32[3]{0} parameter(9) + compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE + constant.6846 = pred[] constant(true) + reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={} + param_5.528 = f32[2,512]{1,0} parameter(5) + slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]} + bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384) + constant.5418 = f32[] constant(0) + broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={} + select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227) + add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173) + param_0.299 = s32[] parameter(0) + constant.5157 = s32[] constant(11) + dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299) + slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]} + constant.6800 = s32[] constant(0) + broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={} + param_8.484 = s32[3]{0} parameter(8) + compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE + constant.6798 = pred[] constant(true) + reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={} + param_3.1169 = f32[2,512]{1,0} parameter(3) + slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]} + bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382) + select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227) + add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172) + constant.5154 = s32[] constant(10) + dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299) + slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]} + constant.6794 = s32[] constant(0) + broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={} + param_7.478 = s32[3]{0} parameter(7) + compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE + constant.6793 = pred[] constant(true) + reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={} + param_2.1685 = f32[2,512]{1,0} parameter(2) + slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]} + bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380) + select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227) + add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171) + constant.5153 = s32[] constant(9) + dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299) + slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]} + constant.6788 = s32[] constant(0) + broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={} + param_6.495 = s32[3]{0} parameter(6) + compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE + constant.6786 = pred[] constant(true) + reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={} + param_1.1408 = f32[2,512]{1,0} parameter(1) + slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]} + bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378) + select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227) + add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170) + constant.5152 = s32[] constant(8) + ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299) +} + +fused_computation.2 { + param_4.655 = f32[2,20,256]{2,0,1} parameter(4) + slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]} + param_6.483 = pred[] parameter(6) + broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={} + param_5.525 = f32[2,512]{1,0} parameter(5) + slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]} + bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368) + constant.5415 = f32[] constant(0) + broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={} + select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225) + add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161) + param_0.265 = s32[] parameter(0) + constant.5151 = s32[] constant(7) + dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265) + slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]} + constant.6782 = s32[] constant(0) + broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={} + param_9.391 = s32[3]{0} parameter(9) + compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE + constant.6781 = pred[] constant(true) + reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={} + param_3.1167 = f32[2,512]{1,0} parameter(3) + slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]} + bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366) + select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225) + add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160) + constant.5150 = s32[] constant(6) + dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265) + slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]} + constant.6776 = s32[] constant(0) + broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={} + param_8.464 = s32[3]{0} parameter(8) + compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE + constant.6775 = pred[] constant(true) + reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={} + param_2.1684 = f32[2,512]{1,0} parameter(2) + slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]} + bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364) + select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225) + add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159) + constant.5149 = s32[] constant(5) + dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265) + slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]} + constant.6770 = s32[] constant(0) + broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={} + param_7.458 = s32[3]{0} parameter(7) + compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE + constant.6769 = pred[] constant(true) + reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={} + param_1.1405 = f32[2,512]{1,0} parameter(1) + slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]} + bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362) + select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225) + add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158) + constant.5148 = s32[] constant(4) + ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265) +} + +ENTRY main { + param_0.0 = s32[] parameter(0) + param_1.0 = f32[2,512]{1,0} parameter(1) + param_2.0 = f32[2,512]{1,0} parameter(2) + param_3.0 = f32[2,512]{1,0} parameter(3) + param_4.0 = f32[2,20,256]{2,1,0} parameter(4) + param_5.0 = f32[2,512]{1,0} parameter(5) + param_6.0 = s32[3]{0} parameter(6) + param_7.0 = s32[3]{0} parameter(7) + param_8.0 = s32[3]{0} parameter(8) + param_9.0 = s32[3]{0} parameter(9) + fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1 + param_10 = pred[] parameter(10) + fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2 + ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2) +} + )") + .ValueOrDie(); + EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla