No need to have this distinction any longer; you can simply call CreateNewUnverifiedModule or CreateNewVerifiedModule as you please. PiperOrigin-RevId: 221018719
		
			
				
	
	
		
			221 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			221 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| ==============================================================================*/
 | |
| 
 | |
| #include "tensorflow/compiler/xla/service/defuser.h"
 | |
| 
 | |
| #include "tensorflow/compiler/xla/literal.h"
 | |
| #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 | |
| #include "tensorflow/compiler/xla/shape_util.h"
 | |
| #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 | |
| 
 | |
| namespace op = xla::testing::opcode_matchers;
 | |
| 
 | |
| namespace xla {
 | |
| namespace {
 | |
| 
 | |
| class DefuserTest : public HloTestBase {
 | |
|  protected:
 | |
|   // Returns the number of fusion instructions in the module.
 | |
|   int FusionCount(const HloModule* m) {
 | |
|     int count = 0;
 | |
|     for (HloComputation* computation : m->computations()) {
 | |
|       if (computation->IsFusionComputation()) {
 | |
|         count++;
 | |
|       }
 | |
|     }
 | |
|     return count;
 | |
|   }
 | |
| 
 | |
|   Defuser defuser_;
 | |
|   const Shape shape_ = ShapeUtil::MakeShape(F32, {2, 2});
 | |
| };
 | |
| 
 | |
| TEST_F(DefuserTest, NoFusionInstruction) {
 | |
|   auto m = CreateNewVerifiedModule();
 | |
|   auto builder = HloComputation::Builder(TestName());
 | |
|   auto param0 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
 | |
|   auto param1 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
 | |
|   builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
 | |
| 
 | |
|   m->AddEntryComputation(builder.Build());
 | |
|   EXPECT_EQ(0, FusionCount(m.get()));
 | |
| 
 | |
|   EXPECT_FALSE(defuser_.Run(m.get()).ValueOrDie());
 | |
| }
 | |
| 
 | |
| TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) {
 | |
|   auto m = CreateNewVerifiedModule();
 | |
|   auto builder = HloComputation::Builder(TestName());
 | |
|   auto param0 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
 | |
|   auto param1 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
 | |
|   auto add = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
 | |
| 
 | |
|   auto computation = m->AddEntryComputation(builder.Build());
 | |
|   computation->CreateFusionInstruction({add},
 | |
|                                        HloInstruction::FusionKind::kLoop);
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(), op::Fusion());
 | |
| 
 | |
|   EXPECT_EQ(1, FusionCount(m.get()));
 | |
|   EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
 | |
|   EXPECT_EQ(0, FusionCount(m.get()));
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(),
 | |
|               op::Add(op::Parameter(), op::Parameter()));
 | |
| }
 | |
| 
 | |
| TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) {
 | |
|   auto m = CreateNewVerifiedModule();
 | |
|   auto builder = HloComputation::Builder(TestName());
 | |
|   auto param0 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
 | |
|   auto param1 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
 | |
|   auto add = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
 | |
|   builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
 | |
| 
 | |
|   auto computation = m->AddEntryComputation(builder.Build());
 | |
|   computation->CreateFusionInstruction({add},
 | |
|                                        HloInstruction::FusionKind::kLoop);
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion()));
 | |
| 
 | |
|   EXPECT_EQ(1, FusionCount(m.get()));
 | |
|   EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
 | |
|   EXPECT_EQ(0, FusionCount(m.get()));
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(),
 | |
|               op::Negate(op::Add(op::Parameter(), op::Parameter())));
 | |
| }
 | |
| 
 | |
| TEST_F(DefuserTest, NonTrivialFusionInstruction) {
 | |
|   auto m = CreateNewVerifiedModule();
 | |
|   auto builder = HloComputation::Builder(TestName());
 | |
|   auto param0 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
 | |
|   auto param1 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
 | |
|   auto param3 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
 | |
|   auto add = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
 | |
|   auto negate = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
 | |
|   auto sub = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
 | |
|   auto mul = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3));
 | |
|   auto div = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
 | |
|   auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
 | |
|       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
 | |
|   auto add2 = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
 | |
| 
 | |
|   auto computation = m->AddEntryComputation(builder.Build());
 | |
|   computation->CreateFusionInstruction(
 | |
|       {add2, constant, div, mul, sub, negate, add},
 | |
|       HloInstruction::FusionKind::kLoop);
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(), op::Fusion());
 | |
| 
 | |
|   EXPECT_EQ(1, FusionCount(m.get()));
 | |
|   EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
 | |
|   EXPECT_EQ(0, FusionCount(m.get()));
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(),
 | |
|               op::Add(op::Constant(), op::Divide()));
 | |
| }
 | |
| 
 | |
| TEST_F(DefuserTest, MultipleFusionInstructions) {
 | |
|   auto m = CreateNewVerifiedModule();
 | |
|   auto builder = HloComputation::Builder(TestName());
 | |
|   auto param0 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
 | |
|   auto param1 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
 | |
|   auto param3 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
 | |
|   auto add = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
 | |
|   auto negate = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
 | |
|   auto sub = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
 | |
|   auto mul = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3));
 | |
|   auto div = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
 | |
|   auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
 | |
|       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
 | |
|   auto add2 = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
 | |
| 
 | |
|   auto computation = m->AddEntryComputation(builder.Build());
 | |
|   computation->CreateFusionInstruction({add2, constant, div, mul},
 | |
|                                        HloInstruction::FusionKind::kLoop);
 | |
|   computation->CreateFusionInstruction({sub, negate, add},
 | |
|                                        HloInstruction::FusionKind::kLoop);
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(), op::Fusion());
 | |
| 
 | |
|   EXPECT_EQ(2, FusionCount(m.get()));
 | |
|   EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
 | |
|   EXPECT_EQ(0, FusionCount(m.get()));
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(),
 | |
|               op::Add(op::Constant(), op::Divide()));
 | |
| }
 | |
| 
 | |
| TEST_F(DefuserTest, NestedFusionInstructions) {
 | |
|   auto m = CreateNewVerifiedModule();
 | |
|   auto builder = HloComputation::Builder(TestName());
 | |
|   auto param0 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
 | |
|   auto param1 =
 | |
|       builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
 | |
|   auto add = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
 | |
|   auto negate = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
 | |
| 
 | |
|   auto computation = m->AddEntryComputation(builder.Build());
 | |
|   auto outer_fusion = computation->CreateFusionInstruction(
 | |
|       {negate, add}, HloInstruction::FusionKind::kLoop);
 | |
|   HloInstruction* fused_negate = outer_fusion->fused_expression_root();
 | |
|   ASSERT_EQ(fused_negate->opcode(), HloOpcode::kNegate);
 | |
|   outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
 | |
|       {fused_negate}, HloInstruction::FusionKind::kLoop);
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(), op::Fusion());
 | |
| 
 | |
|   EXPECT_EQ(2, FusionCount(m.get()));
 | |
|   EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie());
 | |
|   EXPECT_EQ(0, FusionCount(m.get()));
 | |
| 
 | |
|   EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add()));
 | |
| }
 | |
| 
 | |
| }  // namespace
 | |
| }  // namespace xla
 |