[XLA] Don't create invalid output fusions
We were incorrectly allowing Ax+B fusions where B was also the `Ax` dot operation.

PiperOrigin-RevId: 222893743
This commit is contained in:
parent
11f77c8e22
commit
c81c9a4cf2
@ -48,10 +48,15 @@ bool IsMatrixVectorDot(const HloInstruction* hlo) {
|
|||||||
(hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
|
(hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true when `hlo_instr` has exactly one user AND that user lists
// `hlo_instr` exactly once among its operands. A single user that consumes the
// instruction twice (e.g. add(c, c)) therefore does not qualify.
bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
|
||||||
|
return hlo_instr.user_count() == 1 &&
|
||||||
|
absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Side-by-side diff of CanBeOutputFused: each line appears twice (before /
// after the change); only the last condition differs between the two versions.
// The function returns true when `consumer` is a kAdd, `producer` satisfies
// IsMatrixVectorDot, and the use-count condition on `producer` holds.
bool CanBeOutputFused(const HloInstruction* producer,
|
bool CanBeOutputFused(const HloInstruction* producer,
|
||||||
const HloInstruction* consumer) {
|
const HloInstruction* consumer) {
|
||||||
return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) &&
|
return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) &&
|
||||||
// Before: only the number of distinct users was checked.
producer->user_count() == 1;
|
// After: the producer must also appear only once in that user's operands.
// NOTE(review): HasExactlyOneUse returns bool, so the `== 1` is redundant.
HasExactlyOneUse(*producer) == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
|
bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
|
||||||
|
@ -763,6 +763,28 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
|
|||||||
Not(op::Fusion()));
|
Not(op::Fusion()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Regression test: a dot `c` that feeds BOTH operands of `add(c, c)` must not
// be output-fused by CpuInstructionFusion.
TEST_F(InstructionFusionTest,
|
||||||
|
DotOperationFusion_DontOutputFuseDuplicateOperands) {
|
||||||
|
absl::string_view module_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
a = f32[50,60]{1,0} parameter(0)
|
||||||
|
b = f32[60,1]{1,0} parameter(1)
|
||||||
|
c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||||
|
ROOT d = f32[50,1]{1,0} add(c, c)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
// Parse the HLO text above, then run the CPU instruction-fusion pass on it.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_string));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
|
||||||
|
CpuInstructionFusion().Run(module.get()));
|
||||||
|
// The pass must report no change, and the root must not be a fusion.
EXPECT_FALSE(fused_something);
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
Not(op::Fusion()));
|
||||||
|
}
|
||||||
|
|
||||||
struct GatherLoopFusionTestSpec {
|
struct GatherLoopFusionTestSpec {
|
||||||
string test_name;
|
string test_name;
|
||||||
string hlo_computation_text;
|
string hlo_computation_text;
|
||||||
|
@ -181,7 +181,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} else if (consumer->operand_count() == 2 &&
|
} else if (consumer->operand_count() == 2 &&
|
||||||
consumer->opcode() == HloOpcode::kAdd) {
|
consumer->opcode() == HloOpcode::kAdd &&
|
||||||
|
consumer->operand(other_operand_index) != producer) {
|
||||||
// Fuse a bias add into the output of the dot.
|
// Fuse a bias add into the output of the dot.
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -358,6 +358,29 @@ TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) {
|
|||||||
op::Parameter()));
|
op::Parameter()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Regression test: a dot `c` that feeds BOTH operands of `add(c, c)` must not
// be output-fused by GpuInstructionFusion.
TEST_F(InstructionFusionTest,
|
||||||
|
DotOperationFusion_DontOutputFuseDuplicateOperands) {
|
||||||
|
absl::string_view module_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
a = f32[50,60]{1,0} parameter(0)
|
||||||
|
b = f32[60,1]{1,0} parameter(1)
|
||||||
|
c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||||
|
ROOT d = f32[50,1]{1,0} add(c, c)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
// Parse the HLO text above, then run the GPU instruction-fusion pass on it
// (constructed with may_duplicate=false).
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_string));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
bool fused_something,
|
||||||
|
GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get()));
|
||||||
|
// The pass must report no change, and the root must not be a fusion.
EXPECT_FALSE(fused_something);
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
Not(op::Fusion()));
|
||||||
|
}
|
||||||
|
|
||||||
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
|
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
|
||||||
// duplicated and fused into both reduces.
|
// duplicated and fused into both reduces.
|
||||||
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
|
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
|
||||||
|
Loading…
Reference in New Issue
Block a user