476 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			476 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| ==============================================================================*/
 | |
| 
 | |
| #include "tensorflow/compiler/xla/service/instruction_fusion.h"
 | |
| 
 | |
| #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 | |
| #include "tensorflow/compiler/xla/service/hlo_parser.h"
 | |
| #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 | |
| 
 | |
| namespace xla {
 | |
| 
 | |
| namespace op = xla::testing::opcode_matchers;
 | |
| 
 | |
| using InstructionFusionTest = HloTestBase;
 | |
| 
 | |
| // Subclass of InstructionFusion exposing the protected methods Fuse and
 | |
| // FuseIntoMultiOutput for testing.
 | |
| class InstructionFusionForTesting : public InstructionFusion {
 | |
|  public:
 | |
|   explicit InstructionFusionForTesting(HloModule* module)
 | |
|       : InstructionFusion(InstructionFusion::IsExpensive) {
 | |
|     module_ = module;
 | |
|     computation_ = module->entry_computation();
 | |
|   }
 | |
| 
 | |
|   HloInstruction* Fuse(HloInstruction* producer,
 | |
|                        HloInstruction* consumer) override {
 | |
|     return InstructionFusion::Fuse(producer, consumer);
 | |
|   }
 | |
| 
 | |
|   HloInstruction* FuseIntoMultiOutput(HloInstruction* producer,
 | |
|                                       HloInstruction* consumer) override {
 | |
|     return InstructionFusion::FuseIntoMultiOutput(producer, consumer);
 | |
|   }
 | |
| };
 | |
| 
 | |
| TEST_F(InstructionFusionTest, FuseInstructions) {
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY entry_computation {
 | |
|     p0 = f32[4,3]{1,0} parameter(0)
 | |
|     add = f32[4,3]{1,0} add(p0, p0)
 | |
|     ROOT sub = f32[4,3]{1,0} subtract(add, p0)
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
|   HloInstruction* sub = module->entry_computation()->root_instruction();
 | |
|   HloInstruction* add = sub->mutable_operand(0);
 | |
|   HloInstruction* fusion =
 | |
|       InstructionFusionForTesting(module.get()).Fuse(add, sub);
 | |
| 
 | |
|   ASSERT_THAT(fusion, op::Fusion()) << module->ToString();
 | |
|   EXPECT_THAT(fusion->fused_expression_root(),
 | |
|               op::Subtract(op::Add(), op::Parameter()))
 | |
|       << module->ToString();
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   fused_computation {
 | |
|     p1 = f32[4,3] parameter(0)
 | |
|     add = f32[4,3] add(p1, p1)
 | |
|   }
 | |
|   ENTRY entry_computation {
 | |
|     p0 = f32[4,3] parameter(0)
 | |
|     abs = f32[4,3] abs(p0)
 | |
|     ROOT fusion = f32[4,3] fusion(abs), kind=kLoop, calls=fused_computation
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
|   HloInstruction* root = module->entry_computation()->root_instruction();
 | |
|   HloInstruction* abs = root->mutable_operand(0);
 | |
|   HloInstruction* fusion =
 | |
|       InstructionFusionForTesting(module.get()).Fuse(abs, root);
 | |
| 
 | |
|   ASSERT_THAT(fusion, op::Fusion()) << module->ToString();
 | |
|   EXPECT_THAT(fusion->fused_expression_root(), op::Add(op::Abs(), op::Abs()))
 | |
|       << module->ToString();
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) {
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY entry_computation {
 | |
|     p0 = f32[4,3]{1,0} parameter(0)
 | |
|     abs = f32[4,3]{1,0} abs(p0)
 | |
|     tanh = f32[4,3]{1,0} tanh(abs)
 | |
|     ROOT add = f32[4,3]{1,0} add(abs, tanh)
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
|   HloInstruction* root = module->entry_computation()->root_instruction();
 | |
|   HloInstruction* abs = root->mutable_operand(0);
 | |
|   HloInstruction* tanh = root->mutable_operand(1);
 | |
|   HloInstruction* fusion =
 | |
|       InstructionFusionForTesting(module.get()).FuseIntoMultiOutput(abs, tanh);
 | |
| 
 | |
|   ASSERT_THAT(fusion, op::Fusion()) << module->ToString();
 | |
|   EXPECT_THAT(fusion->fused_expression_root(), op::Tuple(op::Tanh(), op::Abs()))
 | |
|       << module->ToString();
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) {
 | |
|   HloComputation::Builder builder(TestName());
 | |
|   auto shape = ShapeUtil::MakeShape(F32, {16, 16});
 | |
|   auto param0 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
 | |
|   auto param1 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
 | |
|   HloInstruction* binary1 = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
 | |
|   auto token = builder.AddInstruction(HloInstruction::CreateToken());
 | |
|   auto send =
 | |
|       builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0));
 | |
|   builder.AddInstruction(HloInstruction::CreateSendDone(send));
 | |
|   HloInstruction* unary = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
 | |
| 
 | |
|   auto module = CreateNewVerifiedModule();
 | |
|   auto computation = module->AddEntryComputation(builder.Build());
 | |
|   EXPECT_EQ(unary, computation->root_instruction());
 | |
|   EXPECT_FALSE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
| }
 | |
| 
 | |
| // Counts the number of HLO ops with a given op code in the specified module.
 | |
| static int Count(const HloModule& module, HloOpcode op) {
 | |
|   int count = 0;
 | |
|   for (const auto* computation : module.computations()) {
 | |
|     for (const auto* instruction : computation->instructions()) {
 | |
|       if (instruction->opcode() == op) {
 | |
|         ++count;
 | |
|       }
 | |
|     }
 | |
|   }
 | |
|   return count;
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY OutputFusion {
 | |
|     p0 = f32[4,3]{1,0} parameter(0)
 | |
|     add = f32[4,3]{1,0} add(p0, p0)
 | |
|     ROOT root = f32[4,3]{1,0} subtract(add, add)
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
|   // Expect the add and subtraction to be fused.
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
 | |
| 
 | |
|   // Make sure the add hasn't been duplicated.
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
 | |
|   // Make sure we do not duplicate the add, as we cannot fuse through the rng.
 | |
|   //
 | |
|   // (p0, p1) -> add -------------------------> sub
 | |
|   //                 \-> abs1 -> rng -> abs2 -/
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY OutputFusion {
 | |
|     p0 = f32[] parameter(0)
 | |
|     p1 = f32[] parameter(1)
 | |
|     add = f32[] add(p0, p1)
 | |
|     abs1 = f32[] abs(add)
 | |
|     rng = f32[] rng(p1, abs1), distribution=rng_uniform
 | |
|     abs2 = f32[] abs(rng)
 | |
|     abs3 = f32[] abs(rng)    
 | |
|     ROOT root = f32[] subtract(abs2, add)
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
|   // We expect abs2 to be fused into root.
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
|   HloInstruction* root = module->entry_computation()->root_instruction();
 | |
|   EXPECT_THAT(root, op::Fusion());
 | |
|   EXPECT_THAT(root->fused_expression_root(),
 | |
|               op::Subtract(op::Abs(op::Parameter()), op::Parameter()))
 | |
|       << module->ToString();
 | |
| 
 | |
|   // Make sure the add hasn't been duplicated.
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
 | |
| 
 | |
|   // Use a log node with a second consumer to break the fusion.
 | |
|   //
 | |
|   // (p0, p1) -> add -------------------------> sub
 | |
|   //                 \-> abs1 -> log -> abs2 -/
 | |
|   //                                 \-> send -> send-done
 | |
|   module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY OutputFusion {
 | |
|     p0 = f32[4,3]{1,0} parameter(0)
 | |
|     p1 = f32[4,3]{1,0} parameter(1)
 | |
|     add = f32[4,3]{1,0} add(p0, p1)
 | |
|     abs1 = f32[4,3]{1,0} abs(add)
 | |
|     log = f32[4,3]{1,0} log(abs1)
 | |
|     token0 = token[] after-all()
 | |
|     send = f32[4,3]{1,0} send(log, token0), channel_id=1
 | |
|     send-done = token[] send-done(send), channel_id=1
 | |
|     abs2 = f32[4,3]{1,0} abs(log)
 | |
|     ROOT root = f32[4,3]{1,0} subtract(abs2, add)
 | |
|   })")
 | |
|                .ValueOrDie();
 | |
| 
 | |
|   // We expect abs2 to be fused into root and abs1 to be fused into log.
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kFusion), 2) << module->ToString();
 | |
| 
 | |
|   // Make sure the add hasn't been duplicated.
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
 | |
| 
 | |
|   // Make sure we still fuse ops where one operand in the chain to the producer
 | |
|   // can't be fused.
 | |
|   //
 | |
|   // (p0, p1) ---> add1 -----------> sub
 | |
|   //          \         \-> add2 -/
 | |
|   //           \-> log -/
 | |
|   //                   \-> send -> send-done
 | |
|   module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY OutputFusion {
 | |
|     p0 = f32[4,3]{1,0} parameter(0)
 | |
|     p1 = f32[4,3]{1,0} parameter(1)
 | |
|     add1 = f32[4,3]{1,0} add(p0, p1)
 | |
|     log = f32[4,3]{1,0} log(p0)
 | |
|     token0 = token[] after-all()
 | |
|     send = f32[4,3]{1,0} send(log, token0), channel_id=1
 | |
|     send-done = token[] send-done(send), channel_id=1
 | |
|     add2 = f32[4,3]{1,0} add(log, add1)
 | |
|     ROOT root = f32[4,3]{1,0} subtract(add1, add2)
 | |
|   })")
 | |
|                .ValueOrDie();
 | |
| 
 | |
|   // Expect the add1 and add2 to be fused into root.
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
 | |
| 
 | |
|   // Make sure we didn't duplicate any adds.
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
 | |
| 
 | |
|   // A variant of the above that allows the algorithm to put add2 into the set
 | |
|   // of unfusible ops to short-circuit the decision whether add1 should be fused
 | |
|   // into sub2.
 | |
|   //
 | |
|   //             /---------------\
 | |
|   // (p0, p1) ---> add1 ---> add2 ------> sub2
 | |
|   //                             \------> sub1
 | |
|   //                              log -/
 | |
|   //                                  \-> send
 | |
|   module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY OutputFusion {
 | |
|     p0 = f32[4,3]{1,0} parameter(0)
 | |
|     p1 = f32[4,3]{1,0} parameter(1)
 | |
|     add1 = f32[4,3]{1,0} add(p0, p1)
 | |
|     add2 = f32[4,3]{1,0} add(add1, p1)
 | |
|     log = f32[4,3]{1,0} log(add2)
 | |
|     token0 = token[] after-all()
 | |
|     send = f32[4,3]{1,0} send(log, token0), channel_id=1
 | |
|     send-done = token[] send-done(send), channel_id=1
 | |
|     sub1 = f32[4,3]{1,0} subtract(log, add2)
 | |
|     sub2 = f32[4,3]{1,0} subtract(add2, add1)
 | |
|     ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
 | |
|   })")
 | |
|                .ValueOrDie();
 | |
| 
 | |
|   // Expect sub1 and sub2 to be fused into root.
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
|   root = module->entry_computation()->root_instruction();
 | |
|   EXPECT_THAT(root, op::Fusion());
 | |
|   EXPECT_THAT(root->fused_expression_root(),
 | |
|               op::Tuple(op::Subtract(op::Parameter(), op::Parameter()),
 | |
|                         op::Subtract(op::Parameter(), op::Parameter())))
 | |
|       << module->ToString();
 | |
| 
 | |
|   // Make sure we didn't duplicate any adds.
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
 | |
|   HloComputation::Builder builder(TestName());
 | |
|   auto shape = ShapeUtil::MakeShape(F32, {16, 16});
 | |
|   auto param0 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
 | |
|   HloInstruction* unary1 = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
 | |
|   auto token = builder.AddInstruction(HloInstruction::CreateToken());
 | |
|   auto send =
 | |
|       builder.AddInstruction(HloInstruction::CreateSend(unary1, token, 0));
 | |
|   builder.AddInstruction(HloInstruction::CreateSendDone(send));
 | |
|   HloInstruction* unary2 = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
 | |
| 
 | |
|   auto module = CreateNewVerifiedModule();
 | |
|   auto computation = module->AddEntryComputation(builder.Build());
 | |
|   EXPECT_EQ(unary2, computation->root_instruction());
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie());
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
 | |
|   auto shape = ShapeUtil::MakeShape(F32, {16, 16});
 | |
|   auto small_shape = ShapeUtil::MakeShape(F32, {16});
 | |
|   HloComputation::Builder builder(TestName());
 | |
|   auto param0 = builder.AddInstruction(
 | |
|       HloInstruction::CreateParameter(0, small_shape, "0"));
 | |
|   auto param1 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
 | |
|   auto broadcast = builder.AddInstruction(
 | |
|       HloInstruction::CreateBroadcast(shape, param0, {0}));
 | |
|   HloInstruction* binary1 = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, broadcast, param1));
 | |
|   auto token = builder.AddInstruction(HloInstruction::CreateToken());
 | |
|   auto send =
 | |
|       builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0));
 | |
|   builder.AddInstruction(HloInstruction::CreateSendDone(send));
 | |
|   HloInstruction* unary = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
 | |
| 
 | |
|   auto module = CreateNewVerifiedModule();
 | |
|   auto computation = module->AddEntryComputation(builder.Build());
 | |
|   EXPECT_EQ(unary, computation->root_instruction());
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie());
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) {
 | |
|   // Make sure we do duplicate the add of the same values, even though we cannot
 | |
|   // fuse through the rng.
 | |
|   //
 | |
|   // p0 -> add -------------------------> sub
 | |
|   //           \-> abs1 -> rng -> abs2 -/
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY OutputFusion {
 | |
|     p0 = f32[] parameter(0)
 | |
|     add = f32[] add(p0, p0)
 | |
|     abs1 = f32[] abs(add)
 | |
|     rng = f32[] rng(p0, abs1), distribution=rng_uniform
 | |
|     abs2 = f32[] abs(rng)
 | |
|     abs3 = f32[] abs(rng)
 | |
|     ROOT root = f32[] subtract(abs2, add)
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
|   // We expect abs2 to be fused into root.
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
|   HloInstruction* root = module->entry_computation()->root_instruction();
 | |
|   EXPECT_THAT(root, op::Fusion());
 | |
|   EXPECT_THAT(root->fused_expression_root(),
 | |
|               op::Subtract(op::Abs(op::Parameter()),
 | |
|                            op::Add(op::Parameter(), op::Parameter())))
 | |
|       << module->ToString();
 | |
| 
 | |
|   // Make sure the add has been duplicated.
 | |
|   EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) {
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY Test {
 | |
|     p0 = f32[100] parameter(0)
 | |
|     p1 = f32[100] parameter(1)
 | |
|     add = f32[100] add(p0, p1)
 | |
|     slice1 = f32[99] slice(add), slice={[0:99:1]}
 | |
|     slice2 = f32[99] slice(add), slice={[1:100:1]}
 | |
|     ROOT add2 = f32[99] add(slice1, slice2)
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
| 
 | |
|   HloInstruction* root = module->entry_computation()->root_instruction();
 | |
|   // 'add' would originally need to be duplicated if fused. However after its
 | |
|   // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one
 | |
|   // user and can now be also fused.
 | |
|   EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter()));
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) {
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY Test {
 | |
|     p0 = f32[100] parameter(0)
 | |
|     p1 = f32[100] parameter(1)
 | |
|     add = f32[100] add(p0, p1)
 | |
|     slice1 = f32[99] slice(add), slice={[0:99:1]}
 | |
|     slice2 = f32[99] slice(add), slice={[1:100:1]}
 | |
|     ROOT add2 = f32[99] add(slice1, slice2)
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
| 
 | |
|   HloInstruction* root = module->entry_computation()->root_instruction();
 | |
|   // 'add' would originally need to be duplicated if fused. However after its
 | |
|   // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one
 | |
|   // user and can now be also fused.
 | |
|   EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter()));
 | |
| }
 | |
| 
 | |
| TEST_F(InstructionFusionTest,
 | |
|        WideningConvertsAreAlwaysDuplicableIntoConsumers) {
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|   HloModule test_module
 | |
|   ENTRY Test {
 | |
|     p0 = f16[100] parameter(0)
 | |
|     c = f32[100] convert(p0)
 | |
|     add = f32[100] add(c, c)
 | |
|     ROOT mul = f32[100] multiply(c, c)
 | |
|   })")
 | |
|                     .ValueOrDie();
 | |
| 
 | |
|   // The convert should be fused into the add and mul, even though may_duplicate
 | |
|   // is false, because it's always beneficial to fuse/duplicate widening
 | |
|   // converts into consumers.
 | |
|   EXPECT_TRUE(
 | |
|       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false)
 | |
|           .Run(module.get())
 | |
|           .ValueOrDie())
 | |
|       << module->ToString();
 | |
| 
 | |
|   HloInstruction* root = module->entry_computation()->root_instruction();
 | |
|   EXPECT_THAT(root, op::Fusion(op::Parameter()));
 | |
| }
 | |
| 
 | |
| }  // namespace xla
 |