[XLA] Don't create invalid output fusions

We were incorrectly allowing Ax+B fusions where B was also the `Ax` dot
operation.

PiperOrigin-RevId: 222893743
Sanjoy Das authored 2018-11-26 15:05:10 -08:00, committed by TensorFlower Gardener
parent 11f77c8e22, commit c81c9a4cf2
4 changed files with 53 additions and 2 deletions
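For context, the degenerate pattern this commit rejects is an add whose two operands are the same dot, e.g. d = add(c, c) with c = dot(a, b). The sketch below is a minimal, self-contained illustration (a hypothetical Instr struct, not the XLA HloInstruction API) of why checking only the user count accepts this pattern, while counting operand edges, as the CPU change below does, rejects it.

#include <algorithm>
#include <vector>

// Hypothetical stand-in for an instruction node; not the XLA HloInstruction API.
struct Instr {
  std::vector<const Instr*> operands;
  std::vector<const Instr*> users;
};

// Old-style check: the dot has a single user.
bool HasOneUser(const Instr& dot) { return dot.users.size() == 1; }

// New-style check: the dot has a single user *and* feeds it through exactly
// one operand edge.
bool HasExactlyOneUseSketch(const Instr& dot) {
  return dot.users.size() == 1 &&
         std::count(dot.users.front()->operands.begin(),
                    dot.users.front()->operands.end(), &dot) == 1;
}

int main() {
  Instr dot, bias, add_bias, add_dup;
  add_bias.operands = {&dot, &bias};  // d = add(c, bias): a valid Ax+B fusion
  add_dup.operands = {&dot, &dot};    // d = add(c, c): the broken case

  dot.users = {&add_bias};
  // Both checks accept the genuine bias add.
  bool bias_ok = HasOneUser(dot) && HasExactlyOneUseSketch(dot);  // true

  dot.users = {&add_dup};
  // The old check still accepts add(c, c): there is only one distinct user...
  bool dup_old = HasOneUser(dot);              // true
  // ...but the new check rejects it: two operand edges point at the dot.
  bool dup_new = HasExactlyOneUseSketch(dot);  // false

  return (bias_ok && dup_old && !dup_new) ? 0 : 1;
}

The CPU change implements this edge-count idea with absl::c_count; the GPU change gets the same effect by comparing the add's other operand against the producer.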

@@ -48,10 +48,15 @@ bool IsMatrixVectorDot(const HloInstruction* hlo) {
          (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
 }
 
+bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
+  return hlo_instr.user_count() == 1 &&
+         absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1;
+}
+
 bool CanBeOutputFused(const HloInstruction* producer,
                       const HloInstruction* consumer) {
   return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) &&
-         producer->user_count() == 1;
+         HasExactlyOneUse(*producer) == 1;
 }
 
 bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
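Note on the CPU predicate above: for the ROOT d = add(c, c) pattern exercised by the new test below, the dot c still has user_count() == 1, but absl::c_count finds it twice among that single user's operands, so HasExactlyOneUse returns false and no output fusion is formed.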

@@ -763,6 +763,28 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
               Not(op::Fusion()));
 }
 
+TEST_F(InstructionFusionTest,
+       DotOperationFusion_DontOutputFuseDuplicateOperands) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[60,1]{1,0} parameter(1)
+  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT d = f32[50,1]{1,0} add(c, c)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_FALSE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              Not(op::Fusion()));
+}
+
 struct GatherLoopFusionTestSpec {
   string test_name;
   string hlo_computation_text;

@@ -181,7 +181,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
       return true;
     }
   } else if (consumer->operand_count() == 2 &&
-             consumer->opcode() == HloOpcode::kAdd) {
+             consumer->opcode() == HloOpcode::kAdd &&
+             consumer->operand(other_operand_index) != producer) {
     // Fuse a bias add into the output of the dot.
     return true;
   }
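Note on the GPU guard above: other_operand_index is computed earlier in ShouldFuse (outside this hunk) and presumably names the add operand slot other than the one being considered for fusion; if that slot also holds the producer, the consumer is add(dot, dot) and the bias-add output fusion is skipped.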

@@ -358,6 +358,29 @@ TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) {
                           op::Parameter()));
 }
 
+TEST_F(InstructionFusionTest,
+       DotOperationFusion_DontOutputFuseDuplicateOperands) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[60,1]{1,0} parameter(1)
+  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT d = f32[50,1]{1,0} add(c, c)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool fused_something,
+      GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get()));
+  EXPECT_FALSE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              Not(op::Fusion()));
+}
+
 // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
 // duplicated and fused into both reduces.
 TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {