243 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			243 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| ==============================================================================*/
 | |
| 
 | |
| #include "tensorflow/compiler/xla/service/hlo_reachability.h"
 | |
| 
 | |
| #include <set>
 | |
| 
 | |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 | |
| #include "tensorflow/compiler/xla/test.h"
 | |
| #include "tensorflow/compiler/xla/test_helpers.h"
 | |
| #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 | |
| 
 | |
| namespace xla {
 | |
| 
 | |
| namespace {
 | |
| 
 | |
| class HloReachabilityTest : public HloTestBase {};
 | |
| 
 | |
| TEST_F(HloReachabilityTest, Reachability) {
 | |
|   // Construct and test a reachability graph of the following form:
 | |
|   /*
 | |
|        a
 | |
|       / \
 | |
|      b   c
 | |
|       \ / \
 | |
|        d   e
 | |
|   */
 | |
|   auto builder = HloComputation::Builder(TestName());
 | |
|   auto a = builder.AddInstruction(
 | |
|       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
 | |
|   auto b = builder.AddInstruction(
 | |
|       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
 | |
|   auto c = builder.AddInstruction(
 | |
|       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
 | |
|   auto d = builder.AddInstruction(
 | |
|       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
 | |
|   auto e = builder.AddInstruction(
 | |
|       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
 | |
|   auto module = CreateNewVerifiedModule();
 | |
|   module->AddEntryComputation(builder.Build());
 | |
| 
 | |
|   HloReachabilityMap reachability({a, b, c, d, e});
 | |
|   reachability.SetReachable(a, a);
 | |
|   EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, b));
 | |
|   EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, c));
 | |
|   EXPECT_TRUE(reachability.SetReachabilityToUnion({b, c}, d));
 | |
|   EXPECT_TRUE(reachability.SetReachabilityToUnion({c}, e));
 | |
| 
 | |
|   EXPECT_TRUE(reachability.IsReachable(a, a));
 | |
|   EXPECT_TRUE(reachability.IsReachable(a, b));
 | |
|   EXPECT_TRUE(reachability.IsReachable(a, c));
 | |
|   EXPECT_TRUE(reachability.IsReachable(a, d));
 | |
|   EXPECT_TRUE(reachability.IsReachable(a, e));
 | |
| 
 | |
|   EXPECT_FALSE(reachability.IsReachable(b, a));
 | |
|   EXPECT_TRUE(reachability.IsReachable(b, b));
 | |
|   EXPECT_FALSE(reachability.IsReachable(b, c));
 | |
|   EXPECT_TRUE(reachability.IsReachable(b, d));
 | |
|   EXPECT_FALSE(reachability.IsReachable(b, e));
 | |
| 
 | |
|   EXPECT_FALSE(reachability.IsReachable(e, a));
 | |
|   EXPECT_FALSE(reachability.IsReachable(e, b));
 | |
|   EXPECT_FALSE(reachability.IsReachable(e, c));
 | |
|   EXPECT_FALSE(reachability.IsReachable(e, d));
 | |
|   EXPECT_TRUE(reachability.IsReachable(e, e));
 | |
| 
 | |
|   // Recomputing the same reachability for a previously computed instruction
 | |
|   // should return false (no change).
 | |
|   EXPECT_FALSE(reachability.SetReachabilityToUnion({a}, b));
 | |
|   EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d));
 | |
| }
 | |
| 
 | |
| TEST_F(HloReachabilityTest, NonTrivialReachability) {
 | |
|   // Test reachability of a non-trivial computation:
 | |
|   //
 | |
|   // const1    const2
 | |
|   //    |         |
 | |
|   //    | +-------+
 | |
|   //    | |       |
 | |
|   //    add ..   negate
 | |
|   //     |   .     |
 | |
|   //     |   .... exp
 | |
|   //     |         |
 | |
|   //     +---+   +-+---+
 | |
|   //         |   |     |
 | |
|   //       multiply   copy
 | |
|   //
 | |
|   // There is a control dependency from 'add' to 'exp'.
 | |
|   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
 | |
|   auto builder = HloComputation::Builder(TestName());
 | |
|   auto constant1 = builder.AddInstruction(
 | |
|       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
 | |
|   auto constant2 = builder.AddInstruction(
 | |
|       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
 | |
|   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
 | |
|       r0f32, HloOpcode::kAdd, constant1, constant2));
 | |
|   auto negate = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2));
 | |
|   auto exp = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate));
 | |
|   auto mul = builder.AddInstruction(
 | |
|       HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp));
 | |
|   auto copy = builder.AddInstruction(
 | |
|       HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp));
 | |
| 
 | |
|   auto module = CreateNewVerifiedModule();
 | |
|   auto computation =
 | |
|       module->AddEntryComputation(builder.Build(/*root_instruction=*/mul));
 | |
| 
 | |
|   TF_CHECK_OK(add->AddControlDependencyTo(exp));
 | |
|   auto reachability = HloReachabilityMap::Build(computation);
 | |
| 
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant1, add));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant1, negate));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant1, exp));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant1, mul));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant1, copy));
 | |
| 
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, add));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, negate));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, exp));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, mul));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, copy));
 | |
| 
 | |
|   EXPECT_FALSE(reachability->IsReachable(exp, constant1));
 | |
|   EXPECT_FALSE(reachability->IsReachable(exp, constant2));
 | |
|   EXPECT_FALSE(reachability->IsReachable(exp, add));
 | |
|   EXPECT_FALSE(reachability->IsReachable(exp, negate));
 | |
|   EXPECT_TRUE(reachability->IsReachable(exp, exp));
 | |
|   EXPECT_TRUE(reachability->IsReachable(exp, mul));
 | |
|   EXPECT_TRUE(reachability->IsReachable(exp, copy));
 | |
| 
 | |
|   EXPECT_FALSE(reachability->IsReachable(mul, constant1));
 | |
|   EXPECT_FALSE(reachability->IsReachable(mul, constant2));
 | |
|   EXPECT_FALSE(reachability->IsReachable(mul, add));
 | |
|   EXPECT_FALSE(reachability->IsReachable(mul, negate));
 | |
|   EXPECT_FALSE(reachability->IsReachable(mul, exp));
 | |
|   EXPECT_TRUE(reachability->IsReachable(mul, mul));
 | |
|   EXPECT_FALSE(reachability->IsReachable(mul, copy));
 | |
| 
 | |
|   EXPECT_TRUE(reachability->IsConnected(constant1, copy));
 | |
|   EXPECT_TRUE(reachability->IsConnected(copy, constant1));
 | |
|   EXPECT_FALSE(reachability->IsConnected(negate, add));
 | |
|   EXPECT_FALSE(reachability->IsConnected(add, negate));
 | |
| 
 | |
|   // Remove the control dependency then update and verify the reachability map
 | |
|   ASSERT_IS_OK(add->RemoveControlDependencyTo(exp));
 | |
|   reachability->UpdateReachabilityThroughInstruction(exp);
 | |
| 
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant1, constant1));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant1, constant2));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant1, add));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant1, negate));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant1, exp));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant1, mul));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant1, copy));
 | |
| 
 | |
|   // Change a use within the graph then update and verify the reachability map
 | |
|   ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1));
 | |
|   reachability->UpdateReachabilityThroughInstruction(negate);
 | |
| 
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant2, constant1));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, constant2));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, add));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant2, negate));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant2, exp));
 | |
|   EXPECT_TRUE(reachability->IsReachable(constant2, mul));
 | |
|   EXPECT_FALSE(reachability->IsReachable(constant2, copy));
 | |
| }
 | |
| 
 | |
| TEST_F(HloReachabilityTest, ChannelReachability) {
 | |
|   const Shape shape = ShapeUtil::MakeShape(F32, {5, 7});
 | |
|   HloComputation::Builder builder("ChannelReachability");
 | |
|   auto param = builder.AddInstruction(
 | |
|       HloInstruction::CreateParameter(0, shape, "param"));
 | |
|   auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
 | |
|   auto send =
 | |
|       builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1));
 | |
|   auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
 | |
|   auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
 | |
|   auto recv =
 | |
|       builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1));
 | |
|   auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
 | |
| 
 | |
|   auto module = CreateNewVerifiedModule();
 | |
|   auto computation = module->AddEntryComputation(builder.Build(recv_done));
 | |
|   auto reachability = HloReachabilityMap::Build(computation);
 | |
|   EXPECT_TRUE(reachability->IsReachable(param, recv_done));
 | |
|   EXPECT_FALSE(reachability->IsReachable(send, recv));
 | |
|   EXPECT_FALSE(reachability->IsReachable(send_done, recv));
 | |
| }
 | |
| 
 | |
| TEST_F(HloReachabilityTest, ReplaceInstructions) {
 | |
|   auto module = ParseAndReturnVerifiedModule(R"(
 | |
|     HloModule test
 | |
| 
 | |
|     ENTRY entry {
 | |
|       p0 = f32[28,28]{1,0} parameter(0)
 | |
|       ROOT add = f32[28,28]{1,0} add(p0, p0)
 | |
|     })")
 | |
|                     .ValueOrDie();
 | |
|   auto computation = module->entry_computation();
 | |
|   auto reachability = HloReachabilityMap::Build(computation);
 | |
|   auto* add = module->entry_computation()->root_instruction();
 | |
|   auto* p0 = add->operand(0);
 | |
|   EXPECT_TRUE(reachability->IsReachable(p0, add));
 | |
| 
 | |
|   // Replacing an instruction with itself is a noop.
 | |
|   reachability->Replace(add, add);
 | |
|   EXPECT_TRUE(reachability->IsReachable(p0, add));
 | |
| 
 | |
|   // Introduce a fusion instruction taking the place of `add`.
 | |
|   auto* fusion = computation->AddInstruction(HloInstruction::CreateFusion(
 | |
|       add->shape(), HloInstruction::FusionKind::kLoop, add));
 | |
|   EXPECT_FALSE(reachability->IsPresent(fusion));
 | |
|   EXPECT_TRUE(reachability->IsReachable(p0, add));
 | |
| 
 | |
|   // Replace `add` with `fusion` in the readability map.
 | |
|   reachability->Replace(add, fusion);
 | |
|   EXPECT_FALSE(reachability->IsPresent(add));
 | |
|   EXPECT_TRUE(reachability->IsReachable(p0, fusion));
 | |
| }
 | |
| 
 | |
| }  // namespace
 | |
| 
 | |
| }  // namespace xla
 |