Ensure that the multi-instruction fuse can take shared inputs (#11748)
* Ensure that the multi-instruction fuse can take shared inputs Note that the fuse action only works when the shared input / constant appears after all of its consumers in the list of instructions. * Add a comment describing the test
This commit is contained in:
parent
2381ce5c33
commit
b26f9cd44b
@ -522,6 +522,42 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
|
||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
||||
}
|
||||
|
||||
// When a constant (or other op) which has multiple users is imported
|
||||
// into a fusion, it should remain shared, rather than being duplicated
|
||||
// within the fusion.
|
||||
XLA_TEST_F(FusionTest, SharedConstant) {
|
||||
auto hlo_module = CreateNewModule();
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<int32>({0})));
|
||||
auto const1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
|
||||
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
|
||||
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
|
||||
auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
|
||||
auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
|
||||
hlo_module->AddEntryComputation(builder.Build())
|
||||
->CreateFusionInstruction(
|
||||
{add4, add3, add2, add1, const1},
|
||||
HloInstruction::FusionKind::kLoop);
|
||||
|
||||
HloComputation* entry_comp = hlo_module->entry_computation();
|
||||
|
||||
// entry computation contains the constant(0) and the fusion
|
||||
EXPECT_EQ(entry_comp->instructions().size(), 2);
|
||||
|
||||
// fused instruction contains the constant(2), the parameter, and 4 adds
|
||||
EXPECT_EQ(entry_comp->root_instruction()->fused_instructions().size(), 6);
|
||||
|
||||
LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
|
||||
*ExecuteAndTransfer(std::move(hlo_module), {}));
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
|
||||
|
||||
XLA_TEST_F(FusionTest, Subtract2D) {
|
||||
|
Loading…
Reference in New Issue
Block a user