Relax a check to enable an optimization on all element type.
This commit is contained in:
parent
e603870c4f
commit
c2632730f8
@ -162,8 +162,7 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// The elementwise output shapes must be the same (including layout).
|
// The elementwise output shapes must be the same (including layout).
|
||||||
// TODO(tjoerg): Further relax the constraint. The datatype does not matter.
|
return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
|
||||||
return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1),
|
|
||||||
get_loop_shape(instr_2));
|
get_loop_shape(instr_2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -470,6 +470,41 @@ TEST_F(GpuFusibleTest,
|
|||||||
EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2));
|
EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GpuFusibleTest,
|
||||||
|
ShapesCompatibleForMultiOutputFusion_MultiOutputLoopFusion_DifferentElementType) {
|
||||||
|
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||||
|
fused_computation_1 {
|
||||||
|
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||||
|
mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
|
||||||
|
exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
|
||||||
|
ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
|
||||||
|
}
|
||||||
|
|
||||||
|
fused_computation_2 {
|
||||||
|
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||||
|
const.2 = f32[] constant(0)
|
||||||
|
broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2), dimensions={}
|
||||||
|
add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
|
||||||
|
ROOT convert = s32[8,1,5,16,1,1]{5,4,3,2,1,0} convert(add)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
|
||||||
|
fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
|
||||||
|
fusion.2 = s32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
|
||||||
|
gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
|
||||||
|
gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
|
||||||
|
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
|
||||||
|
})"))
|
||||||
|
.ValueOrDie();
|
||||||
|
const HloInstruction* fusion_1 =
|
||||||
|
module->entry_computation()->root_instruction()->operand(0)->operand(0);
|
||||||
|
const HloInstruction* fusion_2 =
|
||||||
|
module->entry_computation()->root_instruction()->operand(2);
|
||||||
|
EXPECT_NE(fusion_1, fusion_2);
|
||||||
|
EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) {
|
TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) {
|
||||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||||
ENTRY reduce {
|
ENTRY reduce {
|
||||||
|
@ -128,6 +128,17 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
|||||||
return equal;
|
return equal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
|
||||||
|
const Shape& rhs) {
|
||||||
|
bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
|
||||||
|
if (!equal && VLOG_IS_ON(3)) {
|
||||||
|
VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
|
||||||
|
<< lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
|
||||||
|
}
|
||||||
|
|
||||||
|
return equal;
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
|
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
|
||||||
const Shape& rhs) {
|
const Shape& rhs) {
|
||||||
bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
|
bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
|
||||||
|
@ -304,6 +304,9 @@ class ShapeUtil {
|
|||||||
// Returns whether the lhs and rhs shapes are identical.
|
// Returns whether the lhs and rhs shapes are identical.
|
||||||
static bool Equal(const Shape& lhs, const Shape& rhs);
|
static bool Equal(const Shape& lhs, const Shape& rhs);
|
||||||
|
|
||||||
|
// As Equal, but does not compare the element type.
|
||||||
|
static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);
|
||||||
|
|
||||||
// As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
|
// As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
|
||||||
static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
|
static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
|
||||||
|
|
||||||
|
@ -176,6 +176,27 @@ TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) {
|
|||||||
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
|
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ShapeUtilTest, EqualIgnoringElementType) {
|
||||||
|
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
|
||||||
|
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
|
||||||
|
ShapeUtil::MakeShapeWithLayout(S32, {4, 3}, {0, 1}),
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
|
||||||
|
EXPECT_TRUE(ShapeUtil::EqualIgnoringElementType(
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||||
|
ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ShapeUtilTest, UnequalIgnoringElementType) {
|
||||||
|
EXPECT_FALSE(ShapeUtil::EqualIgnoringElementType(
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1})));
|
||||||
|
EXPECT_FALSE(ShapeUtil::EqualIgnoringElementType(
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}),
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0})));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ShapeUtilTest, EqualDynamicShapes) {
|
TEST(ShapeUtilTest, EqualDynamicShapes) {
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}),
|
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}),
|
||||||
|
Loading…
Reference in New Issue
Block a user