Ensure that the multi-instruction fuse can take shared inputs (#11748)

* Ensure that the multi-instruction fuse can take shared inputs

Note that the fuse action only works when the shared input / constant
appears after all of its consumers in the list of instructions.

* Add a comment describing the test
This commit is contained in:
David Norman 2017-07-27 17:58:23 +01:00 committed by Vijay Vasudevan
parent 2381ce5c33
commit b26f9cd44b

View File

@ -522,6 +522,42 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
*ExecuteAndTransfer(std::move(hlo_module), {}));
}
// When a constant (or other op) which has multiple users is imported
// into a fusion, it should remain shared, rather than being duplicated
// within the fusion.
XLA_TEST_F(FusionTest, SharedConstant) {
auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR1<int32>({0})));
auto const1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(
{add4, add3, add2, add1, const1},
HloInstruction::FusionKind::kLoop);
HloComputation* entry_comp = hlo_module->entry_computation();
// entry computation contains the constant(0) and the fusion
EXPECT_EQ(entry_comp->instructions().size(), 2);
// fused instruction contains the constant(2), the parameter, and 4 adds
EXPECT_EQ(entry_comp->root_instruction()->fused_instructions().size(), 6);
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
*ExecuteAndTransfer(std::move(hlo_module), {}));
}
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
XLA_TEST_F(FusionTest, Subtract2D) {