[XLA] Don't create invalid output fusions
We were incorrectly allowing Ax+B fusions where B was also the `Ax` dot operation.

PiperOrigin-RevId: 222893743
commit c81c9a4cf2 (parent 11f77c8e22)
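In HLO terms, the invalid pattern is an add whose would-be bias operand is the very dot it would be output-fused with. The regression tests added below (for both the CPU and GPU fusion passes) exercise exactly this shape:

ENTRY main {
  a = f32[50,60]{1,0} parameter(0)
  b = f32[60,1]{1,0} parameter(1)
  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT d = f32[50,1]{1,0} add(c, c)
}

Here `c` has a single user, but it appears twice in that user's operand list, so fusing the add as dot-plus-bias would treat the dot itself as its own bias operand.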
@@ -48,10 +48,15 @@ bool IsMatrixVectorDot(const HloInstruction* hlo) {
          (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
 }
 
+bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
+  return hlo_instr.user_count() == 1 &&
+         absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1;
+}
+
 bool CanBeOutputFused(const HloInstruction* producer,
                       const HloInstruction* consumer) {
   return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) &&
-         producer->user_count() == 1;
+         HasExactlyOneUse(*producer) == 1;
 }
 
 bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
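For reference, a minimal self-contained sketch of what the new predicate does. The `Instr` struct below is a hypothetical stand-in for `HloInstruction` (just `users` and `operands` pointer lists), not the real XLA API; it only illustrates why `HasExactlyOneUse` rejects `add(c, c)` while accepting an ordinary bias add:

// Hypothetical, simplified model of an instruction's def-use links.
#include <algorithm>
#include <cassert>
#include <vector>

struct Instr {
  std::vector<Instr*> operands;
  std::vector<Instr*> users;
};

// Mirrors the logic of the new helper: exactly one user, and the instruction
// occurs exactly once in that user's operand list.
bool HasExactlyOneUse(const Instr& instr) {
  return instr.users.size() == 1 &&
         std::count(instr.users.front()->operands.begin(),
                    instr.users.front()->operands.end(), &instr) == 1;
}

int main() {
  // d = add(c, c): one user, but c appears twice among that user's operands.
  Instr c, d;
  d.operands = {&c, &c};
  c.users = {&d};
  assert(!HasExactlyOneUse(c));  // duplicate operand -> not output-fusable

  // d2 = add(c2, bias): one user, single occurrence -> the fusable shape.
  Instr c2, bias, d2;
  d2.operands = {&c2, &bias};
  c2.users = {&d2};
  bias.users = {&d2};
  assert(HasExactlyOneUse(c2));
  return 0;
}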
@@ -763,6 +763,28 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
               Not(op::Fusion()));
 }
 
+TEST_F(InstructionFusionTest,
+       DotOperationFusion_DontOutputFuseDuplicateOperands) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[60,1]{1,0} parameter(1)
+  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT d = f32[50,1]{1,0} add(c, c)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_FALSE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              Not(op::Fusion()));
+}
+
 struct GatherLoopFusionTestSpec {
   string test_name;
   string hlo_computation_text;
@@ -181,7 +181,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
       return true;
     }
   } else if (consumer->operand_count() == 2 &&
-             consumer->opcode() == HloOpcode::kAdd) {
+             consumer->opcode() == HloOpcode::kAdd &&
+             consumer->operand(other_operand_index) != producer) {
     // Fuse a bias add into the output of the dot.
     return true;
   }
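A similar sketch for the GPU-side condition. `other_operand_index` is defined earlier in `ShouldFuse`, outside this hunk (presumably as `1 - operand_index`), so the new clause requires that the consumer's other operand is not the producer itself. The struct and helper name below are hypothetical simplifications, not the real `GpuInstructionFusion` API:

#include <cassert>
#include <vector>

// Hypothetical stand-in for an instruction: just its operand list.
struct Instr {
  std::vector<Instr*> operands;
};

// Sketch of the added check for a two-operand add consumer: the operand on the
// "other" side of the producer must not be the producer itself.
bool BiasAddIsOutputFusable(const Instr& consumer, const Instr& producer,
                            int operand_index) {
  const int other_operand_index = 1 - operand_index;
  return consumer.operands.size() == 2 &&
         consumer.operands[other_operand_index] != &producer;
}

int main() {
  Instr dot, bias;
  Instr good{{&dot, &bias}};  // add(dot, bias): other operand is the bias.
  Instr bad{{&dot, &dot}};    // add(dot, dot): other operand is the dot again.
  assert(BiasAddIsOutputFusable(good, dot, /*operand_index=*/0));
  assert(!BiasAddIsOutputFusable(bad, dot, /*operand_index=*/0));
  return 0;
}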
@@ -358,6 +358,29 @@ TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) {
                           op::Parameter()));
 }
 
+TEST_F(InstructionFusionTest,
+       DotOperationFusion_DontOutputFuseDuplicateOperands) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[60,1]{1,0} parameter(1)
+  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT d = f32[50,1]{1,0} add(c, c)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool fused_something,
+      GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get()));
+  EXPECT_FALSE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              Not(op::Fusion()));
+}
+
 // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
 // duplicated and fused into both reduces.
 TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {