[XLA] Fix live range flattening for while loops.
We were incrementing time one too many times when descending to while instructions, causing flattened_instruction_sequence() to have different logical times than instruction_schedule(). PiperOrigin-RevId: 278015948 Change-Id: If71f032a3c1e21210d770556c4d65c5f2a66b01a
This commit is contained in:
parent
edb5ff7234
commit
eb0153ec1f
@ -110,7 +110,6 @@ int64 HloLiveRange::FlattenSchedule(const HloComputation& computation,
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||
time = FlattenSchedule(*instruction->while_condition(), time);
|
||||
time++;
|
||||
time = FlattenSchedule(*instruction->while_body(), time);
|
||||
}
|
||||
}
|
||||
|
||||
@ -65,6 +65,23 @@ class HloLiveRangeTest : public HloTestBase {
|
||||
auto* value = BufferAt(instruction, index);
|
||||
return hlo_live_range_->buffer_live_ranges().at(value);
|
||||
}
|
||||
|
||||
// Checks if the logical times reported by instruction_schedule() matches
|
||||
// flattened_instruction_sequence().
|
||||
void CheckSchedule() const {
|
||||
const auto& flattened_instructions =
|
||||
hlo_live_range_->flattened_instruction_sequence().instructions();
|
||||
EXPECT_EQ(flattened_instructions.size(),
|
||||
hlo_live_range_->instruction_schedule().size());
|
||||
for (const auto& inst_and_time : hlo_live_range_->instruction_schedule()) {
|
||||
EXPECT_EQ(flattened_instructions.at(inst_and_time.second),
|
||||
inst_and_time.first)
|
||||
<< "(flattened_inst[" << inst_and_time.second
|
||||
<< "] = " << flattened_instructions.at(inst_and_time.second)->name()
|
||||
<< ") != (inst_schedule[" << inst_and_time.second
|
||||
<< "] = " << inst_and_time.first->name() << ")";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(HloLiveRangeTest, Multiply) {
|
||||
@ -83,6 +100,8 @@ TEST_F(HloLiveRangeTest, Multiply) {
|
||||
|
||||
Analyze(schedule);
|
||||
|
||||
CheckSchedule();
|
||||
|
||||
// Parameters live from beginning to end.
|
||||
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 3}));
|
||||
EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 3}));
|
||||
@ -111,6 +130,8 @@ TEST_F(HloLiveRangeTest, MultiplyAdd) {
|
||||
|
||||
Analyze(schedule);
|
||||
|
||||
CheckSchedule();
|
||||
|
||||
// Parameters live from beginning to end.
|
||||
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 5}));
|
||||
EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 5}));
|
||||
@ -146,6 +167,8 @@ TEST_F(HloLiveRangeTest, LiveOutBuffers) {
|
||||
|
||||
Analyze(schedule);
|
||||
|
||||
CheckSchedule();
|
||||
|
||||
// Parameters live from beginning to end.
|
||||
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 6}));
|
||||
EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 6}));
|
||||
@ -184,6 +207,8 @@ TEST_F(HloLiveRangeTest, InstructionScheduledAfterRoot) {
|
||||
|
||||
Analyze(schedule);
|
||||
|
||||
CheckSchedule();
|
||||
|
||||
// Parameters live from beginning to end.
|
||||
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 7}));
|
||||
EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 7}));
|
||||
@ -222,6 +247,8 @@ TEST_F(HloLiveRangeTest, AliasedParameter) {
|
||||
|
||||
Analyze(schedule);
|
||||
|
||||
CheckSchedule();
|
||||
|
||||
// Non-readonly parameter live like other normal buffers.
|
||||
EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 2}));
|
||||
|
||||
@ -235,5 +262,77 @@ TEST_F(HloLiveRangeTest, AliasedParameter) {
|
||||
EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 5}));
|
||||
}
|
||||
|
||||
TEST_F(HloLiveRangeTest, While) {
|
||||
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
|
||||
Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
|
||||
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});
|
||||
|
||||
auto cond_builder = HloComputation::Builder("WhileCond");
|
||||
HloInstruction* cond_param = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
|
||||
HloInstruction* cond_iter = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
||||
HloInstruction* cond_limit = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
|
||||
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
|
||||
cond_limit, ComparisonDirection::kLt));
|
||||
HloComputation* cond_computation =
|
||||
module_->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
auto body_builder = HloComputation::Builder("WhileBody");
|
||||
HloInstruction* body_param = body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
|
||||
HloInstruction* body_iter = body_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
|
||||
HloInstruction* body_data = body_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
|
||||
HloInstruction* body_iter_increment = body_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
|
||||
HloInstruction* body_iter_next =
|
||||
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
|
||||
HloInstruction* body_data_increment =
|
||||
body_builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
|
||||
HloInstruction* body_data_mul =
|
||||
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kMultiply, body_data, body_data));
|
||||
HloInstruction* body_data_add =
|
||||
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kAdd, body_data, body_data_increment));
|
||||
HloInstruction* body_data_next =
|
||||
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kAdd, body_data_add, body_data_mul));
|
||||
HloInstruction* body_out = body_builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({body_data_next, body_iter_next}));
|
||||
HloComputation* body_computation =
|
||||
module_->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* data = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param_iter"));
|
||||
HloInstruction* iter = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
|
||||
HloInstruction* tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
|
||||
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
tuple_shape, cond_computation, body_computation, tuple));
|
||||
HloComputation* entry_computation =
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
|
||||
HloSchedule schedule(module_.get());
|
||||
schedule.set_sequence(cond_computation,
|
||||
{cond_param, cond_iter, cond_limit, cond_lt});
|
||||
schedule.set_sequence(body_computation,
|
||||
{body_param, body_iter, body_data, body_iter_increment,
|
||||
body_iter_next, body_data_increment, body_data_mul,
|
||||
body_data_add, body_data_next, body_out});
|
||||
schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
|
||||
|
||||
Analyze(schedule);
|
||||
|
||||
CheckSchedule();
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user