[XLA:GPU] Do not fuse loop fusions with different output shapes.

PiperOrigin-RevId: 209724594
This commit is contained in:
Thomas Joerg 2018-08-22 00:36:58 -07:00 committed by TensorFlower Gardener
parent 611459fd33
commit 05f8ea8e95
2 changed files with 97 additions and 0 deletions

View File

@ -187,6 +187,19 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
return false;
}
// Multi-output loop fusions must have equal output shapes to be lowered.
if (instr1->fusion_kind() == HloInstruction::FusionKind::kLoop) {
Shape shape1 = instr1->IsMultiOutputFusion()
? instr1->shape().tuple_shapes(0)
: instr1->shape();
Shape shape2 = instr2->IsMultiOutputFusion()
? instr2->shape().tuple_shapes(0)
: instr2->shape();
if (!ShapeUtil::Equal(shape1, shape2)) {
return false;
}
}
// Do this check last, as it may be expensive.
return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2);
}

View File

@ -256,6 +256,90 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
// Two sibling kLoop fusions (a multiply and a reduce) consume the same
// parameter but produce different output shapes ([8,1,5,16,1,1] vs
// [8,1,5,1,1]). Multi-output loop fusions require equal output shapes to be
// lowered, so the pass must refuse to fuse them and report no change.
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
}
fused_computation_2 {
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
const.2 = f32[] constant(0)
ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
}
ENTRY entry {
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2)
})"))
.ValueOrDie();
// Run() returns true iff the pass changed the module; here no fusion is
// legal, so it must return false.
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
// A multi-output kLoop fusion (mul + exp) and a sibling single-output kLoop
// fusion (add) share the same parameter and have identical output shapes,
// so the pass should merge them into one multi-output fusion with a
// three-element tuple root.
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
}
fused_computation_2 {
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
const.2 = f32[] constant(0)
ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2)
}
ENTRY entry {
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
})"))
                    .ValueOrDie();
  // Install the trace before the first assertion so that a failing
  // ASSERT_TRUE (pass unexpectedly made no change) also dumps the module.
  SCOPED_TRACE(module->ToString());
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  // After fusion, the root tuple's operands are get-tuple-elements of the
  // single merged fusion; reach it through operand(0)->operand(0).
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  // The merged fusion's root must expose all three results.
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Multiply(), op::Exp(), op::Add()));
}
// Like the test above, but the sibling fusion is a reduce whose output shape
// ([8,1,5,1,1]) differs from the multi-output loop fusion's per-element shape
// ([8,1,5,16,1,1]). Loop fusions with mismatched output shapes cannot be
// lowered as one multi-output fusion, so the pass must make no change.
TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
}
fused_computation_2 {
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
const.2 = f32[] constant(0)
ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
}
ENTRY entry {
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
})"))
.ValueOrDie();
// Run() must report that the module was left unchanged.
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
ENTRY reduce {