Relax a check to enable an optimization for all element types.
This commit is contained in:
parent
e603870c4f
commit
c2632730f8
@ -162,8 +162,7 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
|
||||
return false;
|
||||
}
|
||||
// The elementwise output shapes must be the same (including layout).
|
||||
// TODO(tjoerg): Further relax the constraint. The datatype does not matter.
|
||||
return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1),
|
||||
return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
|
||||
get_loop_shape(instr_2));
|
||||
}
|
||||
|
||||
|
@ -470,6 +470,41 @@ TEST_F(GpuFusibleTest,
|
||||
EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2));
|
||||
}
|
||||
|
||||
// Checks that ShapesCompatibleForMultiOutputFusion returns true for two loop
// fusions whose outputs have the same dimensions and layout but different
// element types: fusion.1 yields a tuple of two f32 arrays, fusion.2 yields
// an s32 array.
// NOTE(review): the entry ROOT tuple declares its third element as f32 while
// fusion.2 is declared as s32 -- confirm whether this mismatch is intended.
TEST_F(GpuFusibleTest,
|
||||
ShapesCompatibleForMultiOutputFusion_MultiOutputLoopFusion_DifferentElementType) {
|
||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||
fused_computation_1 {
|
||||
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
|
||||
exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
|
||||
ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
|
||||
}
|
||||
|
||||
fused_computation_2 {
|
||||
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
const.2 = f32[] constant(0)
|
||||
broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2), dimensions={}
|
||||
add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
|
||||
ROOT convert = s32[8,1,5,16,1,1]{5,4,3,2,1,0} convert(add)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||
fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
|
||||
fusion.2 = s32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
|
||||
gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
|
||||
gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
|
||||
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
// Root operand 0 is gte0; its operand 0 is the multi-output fusion.1.
const HloInstruction* fusion_1 =
|
||||
module->entry_computation()->root_instruction()->operand(0)->operand(0);
|
||||
// Root operand 2 is fusion.2, the fusion with the s32 output.
const HloInstruction* fusion_2 =
|
||||
module->entry_computation()->root_instruction()->operand(2);
|
||||
EXPECT_NE(fusion_1, fusion_2);
|
||||
EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2));
|
||||
}
|
||||
|
||||
TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) {
|
||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||
ENTRY reduce {
|
||||
|
@ -128,6 +128,17 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
return equal;
|
||||
}
|
||||
|
||||
// Compares lhs and rhs using Shape::Equal() configured with
// IgnoreElementType() and returns the result of that comparison.
// When the shapes differ and VLOG level 3 is on, the short debug strings of
// both shapes are logged.
/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
|
||||
const Shape& rhs) {
|
||||
bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
|
||||
// Log only on mismatch, and only when VLOG level 3 is active.
if (!equal && VLOG_IS_ON(3)) {
|
||||
VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
|
||||
<< lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
|
||||
}
|
||||
|
||||
return equal;
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
|
||||
const Shape& rhs) {
|
||||
bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
|
||||
|
@ -304,6 +304,9 @@ class ShapeUtil {
|
||||
// Returns whether the lhs and rhs shapes are identical.
|
||||
static bool Equal(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
// As Equal, but does not compare the element type.
|
||||
static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
// As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
|
||||
static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
|
@ -176,6 +176,27 @@ TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) {
|
||||
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
|
||||
}
|
||||
|
||||
// Shapes with the same dimensions {4, 3} and the same layout {0, 1} are
// expected to compare equal whatever their element types are; the pairs
// covered are F32/F16, S32/F16 and F32/PRED.
TEST(ShapeUtilTest, EqualIgnoringElementType) {
|
||||
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
|
||||
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
|
||||
ShapeUtil::MakeShapeWithLayout(S32, {4, 3}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
|
||||
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
|
||||
}
|
||||
|
||||
// Ignoring the element type must not hide other differences: the first case
// differs in dimensions ({4, 3} vs {3, 4}), the second in layout
// ({0, 1} vs {1, 0}); both are expected to compare unequal.
TEST(ShapeUtilTest, UnequalIgnoringElementType) {
|
||||
EXPECT_FALSE(ShapeUtil::EqualIgnoringElementType(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1})));
|
||||
EXPECT_FALSE(ShapeUtil::EqualIgnoringElementType(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0})));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, EqualDynamicShapes) {
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}),
|
||||
|
Loading…
Reference in New Issue
Block a user