From eb0153ec1f09cc4bc55026b420e3084287ea5e11 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Fri, 1 Nov 2019 15:46:16 -0700 Subject: [PATCH] [XLA] Fix live range flattening for while loops. We were incrementing time one too many times when descending to while instructions, causing flattened_instruction_sequence() to have different logical times than instruction_schedule(). PiperOrigin-RevId: 278015948 Change-Id: If71f032a3c1e21210d770556c4d65c5f2a66b01a --- .../compiler/xla/service/hlo_live_range.cc | 1 - .../xla/service/hlo_live_range_test.cc | 99 +++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/hlo_live_range.cc b/tensorflow/compiler/xla/service/hlo_live_range.cc index 8ec437ec250..5f3808a1a01 100644 --- a/tensorflow/compiler/xla/service/hlo_live_range.cc +++ b/tensorflow/compiler/xla/service/hlo_live_range.cc @@ -110,7 +110,6 @@ int64 HloLiveRange::FlattenSchedule(const HloComputation& computation, } if (instruction->opcode() == HloOpcode::kWhile) { time = FlattenSchedule(*instruction->while_condition(), time); - time++; time = FlattenSchedule(*instruction->while_body(), time); } } diff --git a/tensorflow/compiler/xla/service/hlo_live_range_test.cc b/tensorflow/compiler/xla/service/hlo_live_range_test.cc index d524d9f0c82..232c6b95e88 100644 --- a/tensorflow/compiler/xla/service/hlo_live_range_test.cc +++ b/tensorflow/compiler/xla/service/hlo_live_range_test.cc @@ -65,6 +65,23 @@ class HloLiveRangeTest : public HloTestBase { auto* value = BufferAt(instruction, index); return hlo_live_range_->buffer_live_ranges().at(value); } + + // Checks if the logical times reported by instruction_schedule() matches + // flattened_instruction_sequence(). + void CheckSchedule() const { + const auto& flattened_instructions = + hlo_live_range_->flattened_instruction_sequence().instructions(); + EXPECT_EQ(flattened_instructions.size(), + hlo_live_range_->instruction_schedule().size()); + for (const auto& inst_and_time : hlo_live_range_->instruction_schedule()) { + EXPECT_EQ(flattened_instructions.at(inst_and_time.second), + inst_and_time.first) + << "(flattened_inst[" << inst_and_time.second + << "] = " << flattened_instructions.at(inst_and_time.second)->name() + << ") != (inst_schedule[" << inst_and_time.second + << "] = " << inst_and_time.first->name() << ")"; + } + } }; TEST_F(HloLiveRangeTest, Multiply) { @@ -83,6 +100,8 @@ TEST_F(HloLiveRangeTest, Multiply) { Analyze(schedule); + CheckSchedule(); + // Parameters live from beginning to end. EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 3})); EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 3})); @@ -111,6 +130,8 @@ TEST_F(HloLiveRangeTest, MultiplyAdd) { Analyze(schedule); + CheckSchedule(); + // Parameters live from beginning to end. EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 5})); EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 5})); @@ -146,6 +167,8 @@ TEST_F(HloLiveRangeTest, LiveOutBuffers) { Analyze(schedule); + CheckSchedule(); + // Parameters live from beginning to end. EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 6})); EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 6})); @@ -184,6 +207,8 @@ TEST_F(HloLiveRangeTest, InstructionScheduledAfterRoot) { Analyze(schedule); + CheckSchedule(); + // Parameters live from beginning to end. EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 7})); EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 7})); @@ -222,6 +247,8 @@ TEST_F(HloLiveRangeTest, AliasedParameter) { Analyze(schedule); + CheckSchedule(); + // Non-readonly parameter live like other normal buffers. EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 2})); @@ -235,5 +262,77 @@ TEST_F(HloLiveRangeTest, AliasedParameter) { EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 5})); } +TEST_F(HloLiveRangeTest, While) { + Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3}); + Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + HloInstruction* cond_limit = cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(50.f))); + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_limit, ComparisonDirection::kLt)); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloInstruction* body_iter = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1)); + HloInstruction* body_data = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 0)); + HloInstruction* body_iter_increment = body_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.f))); + HloInstruction* body_iter_next = + body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment)); + HloInstruction* body_data_increment = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}}))); + HloInstruction* body_data_mul = + body_builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kMultiply, body_data, body_data)); + HloInstruction* body_data_add = + body_builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, body_data, body_data_increment)); + HloInstruction* body_data_next = + body_builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, body_data_add, body_data_mul)); + HloInstruction* body_out = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_data_next, body_iter_next})); + HloComputation* body_computation = + module_->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param_iter")); + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({data, iter})); + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module_->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module_.get()); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_limit, cond_lt}); + schedule.set_sequence(body_computation, + {body_param, body_iter, body_data, body_iter_increment, + body_iter_next, body_data_increment, body_data_mul, + body_data_add, body_data_next, body_out}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + + Analyze(schedule); + + CheckSchedule(); +} } // namespace } // namespace xla