Relax a check to enable an optimization on all element type.

This commit is contained in:
Frederic Bastien 2019-06-28 11:48:27 -07:00
parent e603870c4f
commit c2632730f8
5 changed files with 71 additions and 2 deletions

View File

@ -162,8 +162,7 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
return false; return false;
} }
// The elementwise output shapes must be the same (including layout). // The elementwise output shapes must be the same (including layout).
// TODO(tjoerg): Further relax the constraint. The datatype does not matter. return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1),
get_loop_shape(instr_2)); get_loop_shape(instr_2));
} }

View File

@ -470,6 +470,41 @@ TEST_F(GpuFusibleTest,
EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2));
} }
TEST_F(GpuFusibleTest,
ShapesCompatibleForMultiOutputFusion_MultiOutputLoopFusion_DifferentElementType) {
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
}
fused_computation_2 {
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
const.2 = f32[] constant(0)
broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2), dimensions={}
add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
ROOT convert = s32[8,1,5,16,1,1]{5,4,3,2,1,0} convert(add)
}
ENTRY entry {
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
fusion.2 = s32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
})"))
.ValueOrDie();
const HloInstruction* fusion_1 =
module->entry_computation()->root_instruction()->operand(0)->operand(0);
const HloInstruction* fusion_2 =
module->entry_computation()->root_instruction()->operand(2);
EXPECT_NE(fusion_1, fusion_2);
EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2));
}
TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) { TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) {
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
ENTRY reduce { ENTRY reduce {

View File

@ -128,6 +128,17 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return equal; return equal;
} }
/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
const Shape& rhs) {
bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
if (!equal && VLOG_IS_ON(3)) {
VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
<< lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
}
return equal;
}
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) { const Shape& rhs) {
bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs); bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);

View File

@ -304,6 +304,9 @@ class ShapeUtil {
// Returns whether the lhs and rhs shapes are identical. // Returns whether the lhs and rhs shapes are identical.
static bool Equal(const Shape& lhs, const Shape& rhs); static bool Equal(const Shape& lhs, const Shape& rhs);
// As Equal, but does not compare the element type.
static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);
// As Equal, but allow one of lhs and rhs to be F16 while the other is F32. // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

View File

@ -176,6 +176,27 @@ TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) {
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
} }
TEST(ShapeUtilTest, EqualIgnoringElementType) {
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
ShapeUtil::MakeShapeWithLayout(S32, {4, 3}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
}
TEST(ShapeUtilTest, UnequalIgnoringElementType) {
EXPECT_FALSE(ShapeUtil::EqualIgnoringElementType(
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1})));
EXPECT_FALSE(ShapeUtil::EqualIgnoringElementType(
ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}),
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0})));
}
TEST(ShapeUtilTest, EqualDynamicShapes) { TEST(ShapeUtilTest, EqualDynamicShapes) {
EXPECT_TRUE( EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}),