[XLA] Fix live range flattening for while loops.
We were incrementing time one too many times when descending to while instructions, causing flattened_instruction_sequence() to have different logical times than instruction_schedule(). PiperOrigin-RevId: 278015948 Change-Id: If71f032a3c1e21210d770556c4d65c5f2a66b01a
This commit is contained in:
		
							parent
							
								
									edb5ff7234
								
							
						
					
					
						commit
						eb0153ec1f
					
				@ -110,7 +110,6 @@ int64 HloLiveRange::FlattenSchedule(const HloComputation& computation,
 | 
			
		||||
      }
 | 
			
		||||
      if (instruction->opcode() == HloOpcode::kWhile) {
 | 
			
		||||
        time = FlattenSchedule(*instruction->while_condition(), time);
 | 
			
		||||
        time++;
 | 
			
		||||
        time = FlattenSchedule(*instruction->while_body(), time);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -65,6 +65,23 @@ class HloLiveRangeTest : public HloTestBase {
 | 
			
		||||
    auto* value = BufferAt(instruction, index);
 | 
			
		||||
    return hlo_live_range_->buffer_live_ranges().at(value);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Checks if the logical times reported by instruction_schedule() matches
 | 
			
		||||
  // flattened_instruction_sequence().
 | 
			
		||||
  void CheckSchedule() const {
 | 
			
		||||
    const auto& flattened_instructions =
 | 
			
		||||
        hlo_live_range_->flattened_instruction_sequence().instructions();
 | 
			
		||||
    EXPECT_EQ(flattened_instructions.size(),
 | 
			
		||||
              hlo_live_range_->instruction_schedule().size());
 | 
			
		||||
    for (const auto& inst_and_time : hlo_live_range_->instruction_schedule()) {
 | 
			
		||||
      EXPECT_EQ(flattened_instructions.at(inst_and_time.second),
 | 
			
		||||
                inst_and_time.first)
 | 
			
		||||
          << "(flattened_inst[" << inst_and_time.second
 | 
			
		||||
          << "] = " << flattened_instructions.at(inst_and_time.second)->name()
 | 
			
		||||
          << ") != (inst_schedule[" << inst_and_time.second
 | 
			
		||||
          << "] = " << inst_and_time.first->name() << ")";
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(HloLiveRangeTest, Multiply) {
 | 
			
		||||
@ -83,6 +100,8 @@ TEST_F(HloLiveRangeTest, Multiply) {
 | 
			
		||||
 | 
			
		||||
  Analyze(schedule);
 | 
			
		||||
 | 
			
		||||
  CheckSchedule();
 | 
			
		||||
 | 
			
		||||
  // Parameters live from beginning to end.
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 3}));
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 3}));
 | 
			
		||||
@ -111,6 +130,8 @@ TEST_F(HloLiveRangeTest, MultiplyAdd) {
 | 
			
		||||
 | 
			
		||||
  Analyze(schedule);
 | 
			
		||||
 | 
			
		||||
  CheckSchedule();
 | 
			
		||||
 | 
			
		||||
  // Parameters live from beginning to end.
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 5}));
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 5}));
 | 
			
		||||
@ -146,6 +167,8 @@ TEST_F(HloLiveRangeTest, LiveOutBuffers) {
 | 
			
		||||
 | 
			
		||||
  Analyze(schedule);
 | 
			
		||||
 | 
			
		||||
  CheckSchedule();
 | 
			
		||||
 | 
			
		||||
  // Parameters live from beginning to end.
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 6}));
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 6}));
 | 
			
		||||
@ -184,6 +207,8 @@ TEST_F(HloLiveRangeTest, InstructionScheduledAfterRoot) {
 | 
			
		||||
 | 
			
		||||
  Analyze(schedule);
 | 
			
		||||
 | 
			
		||||
  CheckSchedule();
 | 
			
		||||
 | 
			
		||||
  // Parameters live from beginning to end.
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 7}));
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 7}));
 | 
			
		||||
@ -222,6 +247,8 @@ TEST_F(HloLiveRangeTest, AliasedParameter) {
 | 
			
		||||
 | 
			
		||||
  Analyze(schedule);
 | 
			
		||||
 | 
			
		||||
  CheckSchedule();
 | 
			
		||||
 | 
			
		||||
  // Non-readonly parameter live like other normal buffers.
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 2}));
 | 
			
		||||
 | 
			
		||||
@ -235,5 +262,77 @@ TEST_F(HloLiveRangeTest, AliasedParameter) {
 | 
			
		||||
  EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 5}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloLiveRangeTest, While) {
 | 
			
		||||
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
 | 
			
		||||
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
 | 
			
		||||
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});
 | 
			
		||||
 | 
			
		||||
  auto cond_builder = HloComputation::Builder("WhileCond");
 | 
			
		||||
  HloInstruction* cond_param = cond_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
 | 
			
		||||
  HloInstruction* cond_iter = cond_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
 | 
			
		||||
  HloInstruction* cond_limit = cond_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
 | 
			
		||||
  HloInstruction* cond_lt = cond_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
 | 
			
		||||
                                    cond_limit, ComparisonDirection::kLt));
 | 
			
		||||
  HloComputation* cond_computation =
 | 
			
		||||
      module_->AddEmbeddedComputation(cond_builder.Build());
 | 
			
		||||
 | 
			
		||||
  auto body_builder = HloComputation::Builder("WhileBody");
 | 
			
		||||
  HloInstruction* body_param = body_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
 | 
			
		||||
  HloInstruction* body_iter = body_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
 | 
			
		||||
  HloInstruction* body_data = body_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
 | 
			
		||||
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
 | 
			
		||||
  HloInstruction* body_iter_next =
 | 
			
		||||
      body_builder.AddInstruction(HloInstruction::CreateBinary(
 | 
			
		||||
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
 | 
			
		||||
  HloInstruction* body_data_increment =
 | 
			
		||||
      body_builder.AddInstruction(HloInstruction::CreateConstant(
 | 
			
		||||
          LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
 | 
			
		||||
  HloInstruction* body_data_mul =
 | 
			
		||||
      body_builder.AddInstruction(HloInstruction::CreateBinary(
 | 
			
		||||
          shape, HloOpcode::kMultiply, body_data, body_data));
 | 
			
		||||
  HloInstruction* body_data_add =
 | 
			
		||||
      body_builder.AddInstruction(HloInstruction::CreateBinary(
 | 
			
		||||
          shape, HloOpcode::kAdd, body_data, body_data_increment));
 | 
			
		||||
  HloInstruction* body_data_next =
 | 
			
		||||
      body_builder.AddInstruction(HloInstruction::CreateBinary(
 | 
			
		||||
          shape, HloOpcode::kAdd, body_data_add, body_data_mul));
 | 
			
		||||
  HloInstruction* body_out = body_builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateTuple({body_data_next, body_iter_next}));
 | 
			
		||||
  HloComputation* body_computation =
 | 
			
		||||
      module_->AddEmbeddedComputation(body_builder.Build());
 | 
			
		||||
 | 
			
		||||
  auto builder = HloComputation::Builder(TestName());
 | 
			
		||||
  HloInstruction* data = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateParameter(0, shape, "param_iter"));
 | 
			
		||||
  HloInstruction* iter = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
 | 
			
		||||
  HloInstruction* tuple =
 | 
			
		||||
      builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
 | 
			
		||||
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
 | 
			
		||||
      tuple_shape, cond_computation, body_computation, tuple));
 | 
			
		||||
  HloComputation* entry_computation =
 | 
			
		||||
      module_->AddEntryComputation(builder.Build());
 | 
			
		||||
 | 
			
		||||
  HloSchedule schedule(module_.get());
 | 
			
		||||
  schedule.set_sequence(cond_computation,
 | 
			
		||||
                        {cond_param, cond_iter, cond_limit, cond_lt});
 | 
			
		||||
  schedule.set_sequence(body_computation,
 | 
			
		||||
                        {body_param, body_iter, body_data, body_iter_increment,
 | 
			
		||||
                         body_iter_next, body_data_increment, body_data_mul,
 | 
			
		||||
                         body_data_add, body_data_next, body_out});
 | 
			
		||||
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
 | 
			
		||||
 | 
			
		||||
  Analyze(schedule);
 | 
			
		||||
 | 
			
		||||
  CheckSchedule();
 | 
			
		||||
}
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user