[XLA:GPU] Do not fuse loop fusions with different output shapes.
PiperOrigin-RevId: 209724594
This commit is contained in:
parent
611459fd33
commit
05f8ea8e95
@ -187,6 +187,19 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Multi-output loop fusions must have equal output shapes to be lowered.
|
||||
if (instr1->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
Shape shape1 = instr1->IsMultiOutputFusion()
|
||||
? instr1->shape().tuple_shapes(0)
|
||||
: instr1->shape();
|
||||
Shape shape2 = instr2->IsMultiOutputFusion()
|
||||
? instr2->shape().tuple_shapes(0)
|
||||
: instr2->shape();
|
||||
if (!ShapeUtil::Equal(shape1, shape2)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Do this check last, as it may be expensive.
|
||||
return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2);
|
||||
}
|
||||
|
||||
@ -256,6 +256,90 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
|
||||
op::Tuple(op::Multiply(), op::Divide()));
|
||||
}
|
||||
|
||||
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
|
||||
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
|
||||
fused_computation_1 {
|
||||
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
|
||||
}
|
||||
|
||||
fused_computation_2 {
|
||||
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
const.2 = f32[] constant(0)
|
||||
ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
|
||||
fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
|
||||
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
|
||||
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
|
||||
fused_computation_1 {
|
||||
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
|
||||
exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
|
||||
ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
|
||||
}
|
||||
|
||||
fused_computation_2 {
|
||||
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
const.2 = f32[] constant(0)
|
||||
ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
|
||||
fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
|
||||
gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
|
||||
gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
|
||||
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||
SCOPED_TRACE(module->ToString());
|
||||
const HloInstruction* fusion =
|
||||
module->entry_computation()->root_instruction()->operand(0)->operand(0);
|
||||
ASSERT_TRUE(fusion->IsMultiOutputFusion());
|
||||
EXPECT_THAT(fusion->fused_expression_root(),
|
||||
op::Tuple(op::Multiply(), op::Exp(), op::Add()));
|
||||
}
|
||||
|
||||
TEST_F(MultiOutputFusionTest,
|
||||
MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
|
||||
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
|
||||
fused_computation_1 {
|
||||
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
|
||||
exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
|
||||
ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
|
||||
}
|
||||
|
||||
fused_computation_2 {
|
||||
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
const.2 = f32[] constant(0)
|
||||
ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
|
||||
fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
|
||||
gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
|
||||
gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
|
||||
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
|
||||
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
|
||||
ENTRY reduce {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user