[XLA:GPU] Do not create multi-output fusions that would cause exponential compile behavior (time and size).
PiperOrigin-RevId: 337095307
Change-Id: I5fd2d2c169ce80b26f99c62749eb19843d42be07
This commit is contained in:
parent
43e3def545
commit
379bcfa723
@ -913,6 +913,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_reachability",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -148,6 +149,16 @@ std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
|
||||
<< " would be too large of a fusion.";
|
||||
continue;
|
||||
}
|
||||
// Make sure the emitter can codegen the fusion op efficiently. Emitting
|
||||
// certain fusion ops can currently require exponential time/memory, in which
|
||||
// case we don't want to fuse.
|
||||
// TODO(b/119692968): Remove this once fixed in the emitter.
|
||||
if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) {
|
||||
VLOG(3) << "Fusion of " << producer->name() << " into "
|
||||
<< consumer->name()
|
||||
<< " would result in overly large code duplication.";
|
||||
continue;
|
||||
}
|
||||
fusion_candidates.push_back(consumer);
|
||||
}
|
||||
return fusion_candidates;
|
||||
|
@ -909,5 +909,165 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
|
||||
EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
|
||||
}
|
||||
|
||||
// Runs GpuMultiOutputFusion on a module in which fusion.2 consumes fusion.1
// and expects the pass to return false.
//
// Both fused computations are chains of dynamic-update-slice ops where each
// step slices the result of the previous dynamic-update-slice before updating
// it again. As the test name indicates, the pass is expected to decline
// fusing the two kLoop fusions in order to avoid code duplication; the test
// itself only checks the value returned by Run().
TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule module
|
||||
|
||||
and.reduce_sub_computation {
|
||||
x = pred[] parameter(0)
|
||||
y = pred[] parameter(1)
|
||||
ROOT and = pred[] and(x, y)
|
||||
}
|
||||
|
||||
fused_computation.1 {
|
||||
param_4.658 = f32[2,20,256]{2,0,1} parameter(4)
|
||||
slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]}
|
||||
constant.6847 = s32[] constant(0)
|
||||
broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={}
|
||||
param_9.415 = s32[3]{0} parameter(9)
|
||||
compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE
|
||||
constant.6846 = pred[] constant(true)
|
||||
reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||
broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={}
|
||||
param_5.528 = f32[2,512]{1,0} parameter(5)
|
||||
slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]}
|
||||
bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384)
|
||||
constant.5418 = f32[] constant(0)
|
||||
broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={}
|
||||
select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227)
|
||||
add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173)
|
||||
param_0.299 = s32[] parameter(0)
|
||||
constant.5157 = s32[] constant(11)
|
||||
dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299)
|
||||
slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]}
|
||||
constant.6800 = s32[] constant(0)
|
||||
broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={}
|
||||
param_8.484 = s32[3]{0} parameter(8)
|
||||
compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE
|
||||
constant.6798 = pred[] constant(true)
|
||||
reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||
broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={}
|
||||
param_3.1169 = f32[2,512]{1,0} parameter(3)
|
||||
slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]}
|
||||
bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382)
|
||||
select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227)
|
||||
add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172)
|
||||
constant.5154 = s32[] constant(10)
|
||||
dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299)
|
||||
slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]}
|
||||
constant.6794 = s32[] constant(0)
|
||||
broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={}
|
||||
param_7.478 = s32[3]{0} parameter(7)
|
||||
compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE
|
||||
constant.6793 = pred[] constant(true)
|
||||
reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||
broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={}
|
||||
param_2.1685 = f32[2,512]{1,0} parameter(2)
|
||||
slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]}
|
||||
bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380)
|
||||
select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227)
|
||||
add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171)
|
||||
constant.5153 = s32[] constant(9)
|
||||
dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299)
|
||||
slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]}
|
||||
constant.6788 = s32[] constant(0)
|
||||
broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={}
|
||||
param_6.495 = s32[3]{0} parameter(6)
|
||||
compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE
|
||||
constant.6786 = pred[] constant(true)
|
||||
reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||
broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={}
|
||||
param_1.1408 = f32[2,512]{1,0} parameter(1)
|
||||
slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]}
|
||||
bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378)
|
||||
select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227)
|
||||
add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170)
|
||||
constant.5152 = s32[] constant(8)
|
||||
ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299)
|
||||
}
|
||||
|
||||
fused_computation.2 {
|
||||
param_4.655 = f32[2,20,256]{2,0,1} parameter(4)
|
||||
slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]}
|
||||
param_6.483 = pred[] parameter(6)
|
||||
broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={}
|
||||
param_5.525 = f32[2,512]{1,0} parameter(5)
|
||||
slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]}
|
||||
bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368)
|
||||
constant.5415 = f32[] constant(0)
|
||||
broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={}
|
||||
select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225)
|
||||
add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161)
|
||||
param_0.265 = s32[] parameter(0)
|
||||
constant.5151 = s32[] constant(7)
|
||||
dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265)
|
||||
slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]}
|
||||
constant.6782 = s32[] constant(0)
|
||||
broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={}
|
||||
param_9.391 = s32[3]{0} parameter(9)
|
||||
compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE
|
||||
constant.6781 = pred[] constant(true)
|
||||
reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||
broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={}
|
||||
param_3.1167 = f32[2,512]{1,0} parameter(3)
|
||||
slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]}
|
||||
bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366)
|
||||
select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225)
|
||||
add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160)
|
||||
constant.5150 = s32[] constant(6)
|
||||
dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265)
|
||||
slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]}
|
||||
constant.6776 = s32[] constant(0)
|
||||
broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={}
|
||||
param_8.464 = s32[3]{0} parameter(8)
|
||||
compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE
|
||||
constant.6775 = pred[] constant(true)
|
||||
reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||
broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={}
|
||||
param_2.1684 = f32[2,512]{1,0} parameter(2)
|
||||
slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]}
|
||||
bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364)
|
||||
select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225)
|
||||
add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159)
|
||||
constant.5149 = s32[] constant(5)
|
||||
dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265)
|
||||
slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]}
|
||||
constant.6770 = s32[] constant(0)
|
||||
broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={}
|
||||
param_7.458 = s32[3]{0} parameter(7)
|
||||
compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE
|
||||
constant.6769 = pred[] constant(true)
|
||||
reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||
broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={}
|
||||
param_1.1405 = f32[2,512]{1,0} parameter(1)
|
||||
slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]}
|
||||
bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362)
|
||||
select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225)
|
||||
add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158)
|
||||
constant.5148 = s32[] constant(4)
|
||||
ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param_0.0 = s32[] parameter(0)
|
||||
param_1.0 = f32[2,512]{1,0} parameter(1)
|
||||
param_2.0 = f32[2,512]{1,0} parameter(2)
|
||||
param_3.0 = f32[2,512]{1,0} parameter(3)
|
||||
param_4.0 = f32[2,20,256]{2,1,0} parameter(4)
|
||||
param_5.0 = f32[2,512]{1,0} parameter(5)
|
||||
param_6.0 = s32[3]{0} parameter(6)
|
||||
param_7.0 = s32[3]{0} parameter(7)
|
||||
param_8.0 = s32[3]{0} parameter(8)
|
||||
param_9.0 = s32[3]{0} parameter(9)
|
||||
fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1
|
||||
param_10 = pred[] parameter(10)
|
||||
fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2
|
||||
ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2)
|
||||
}
|
||||
)")
|
||||
.ValueOrDie();
|
||||
// Run() is expected to yield false here -- presumably "module unchanged",
// i.e. no multi-output fusion was created (confirm against the pass API).
EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user