[XLA] Fix live range flattening for while loops.

We were incrementing time one too many times when descending to while
instructions, causing flattened_instruction_sequence() to have different logical
times than instruction_schedule().

PiperOrigin-RevId: 278015948
Change-Id: If71f032a3c1e21210d770556c4d65c5f2a66b01a
This commit is contained in:
Berkin Ilbeyi 2019-11-01 15:46:16 -07:00 committed by TensorFlower Gardener
parent edb5ff7234
commit eb0153ec1f
2 changed files with 99 additions and 1 deletions

View File

@ -110,7 +110,6 @@ int64 HloLiveRange::FlattenSchedule(const HloComputation& computation,
}
if (instruction->opcode() == HloOpcode::kWhile) {
time = FlattenSchedule(*instruction->while_condition(), time);
time++;
time = FlattenSchedule(*instruction->while_body(), time);
}
}

View File

@ -65,6 +65,23 @@ class HloLiveRangeTest : public HloTestBase {
auto* value = BufferAt(instruction, index);
return hlo_live_range_->buffer_live_ranges().at(value);
}
// Checks if the logical times reported by instruction_schedule() matches
// flattened_instruction_sequence().
void CheckSchedule() const {
const auto& flattened_instructions =
hlo_live_range_->flattened_instruction_sequence().instructions();
EXPECT_EQ(flattened_instructions.size(),
hlo_live_range_->instruction_schedule().size());
for (const auto& inst_and_time : hlo_live_range_->instruction_schedule()) {
EXPECT_EQ(flattened_instructions.at(inst_and_time.second),
inst_and_time.first)
<< "(flattened_inst[" << inst_and_time.second
<< "] = " << flattened_instructions.at(inst_and_time.second)->name()
<< ") != (inst_schedule[" << inst_and_time.second
<< "] = " << inst_and_time.first->name() << ")";
}
}
};
TEST_F(HloLiveRangeTest, Multiply) {
@ -83,6 +100,8 @@ TEST_F(HloLiveRangeTest, Multiply) {
Analyze(schedule);
CheckSchedule();
// Parameters live from beginning to end.
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 3}));
EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 3}));
@ -111,6 +130,8 @@ TEST_F(HloLiveRangeTest, MultiplyAdd) {
Analyze(schedule);
CheckSchedule();
// Parameters live from beginning to end.
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 5}));
EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 5}));
@ -146,6 +167,8 @@ TEST_F(HloLiveRangeTest, LiveOutBuffers) {
Analyze(schedule);
CheckSchedule();
// Parameters live from beginning to end.
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 6}));
EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 6}));
@ -184,6 +207,8 @@ TEST_F(HloLiveRangeTest, InstructionScheduledAfterRoot) {
Analyze(schedule);
CheckSchedule();
// Parameters live from beginning to end.
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 7}));
EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 7}));
@ -222,6 +247,8 @@ TEST_F(HloLiveRangeTest, AliasedParameter) {
Analyze(schedule);
CheckSchedule();
// Non-readonly parameter live like other normal buffers.
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 2}));
@ -235,5 +262,77 @@ TEST_F(HloLiveRangeTest, AliasedParameter) {
EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 5}));
}
TEST_F(HloLiveRangeTest, While) {
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});
auto cond_builder = HloComputation::Builder("WhileCond");
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
HloInstruction* cond_iter = cond_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
HloInstruction* cond_limit = cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
HloInstruction* cond_lt = cond_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
cond_limit, ComparisonDirection::kLt));
HloComputation* cond_computation =
module_->AddEmbeddedComputation(cond_builder.Build());
auto body_builder = HloComputation::Builder("WhileBody");
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
HloInstruction* body_iter = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
HloInstruction* body_data = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
HloInstruction* body_iter_increment = body_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
HloInstruction* body_iter_next =
body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
HloInstruction* body_data_increment =
body_builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
HloInstruction* body_data_mul =
body_builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kMultiply, body_data, body_data));
HloInstruction* body_data_add =
body_builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, body_data, body_data_increment));
HloInstruction* body_data_next =
body_builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, body_data_add, body_data_mul));
HloInstruction* body_out = body_builder.AddInstruction(
HloInstruction::CreateTuple({body_data_next, body_iter_next}));
HloComputation* body_computation =
module_->AddEmbeddedComputation(body_builder.Build());
auto builder = HloComputation::Builder(TestName());
HloInstruction* data = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param_iter"));
HloInstruction* iter = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
tuple_shape, cond_computation, body_computation, tuple));
HloComputation* entry_computation =
module_->AddEntryComputation(builder.Build());
HloSchedule schedule(module_.get());
schedule.set_sequence(cond_computation,
{cond_param, cond_iter, cond_limit, cond_lt});
schedule.set_sequence(body_computation,
{body_param, body_iter, body_data, body_iter_increment,
body_iter_next, body_data_increment, body_data_mul,
body_data_add, body_data_next, body_out});
schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
Analyze(schedule);
CheckSchedule();
}
} // namespace
} // namespace xla