diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 1fe758e643d..aa18b554976 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1323,22 +1323,24 @@ Status BufferAssigner::AssignPresetBuffers( } const HloAliasAnalysis& alias_analysis = assignment->alias_analysis(); + const HloDataflowAnalysis& dataflow_analysis = + alias_analysis.dataflow_analysis(); for (auto& position_and_chunk : preset_assignments_->chunks()) { const HloPosition& position = position_and_chunk.first; - const HloBuffer& buffer = - alias_analysis.GetUniqueBufferAt(position.instruction, position.index); - VLOG(3) << "Preset allocation for buffer: " << buffer; + const HloValue& value = dataflow_analysis.GetUniqueValueAt( + position.instruction, position.index); + VLOG(3) << "Preset allocation for value: " << value.ToShortString(); const HeapSimulator::Chunk& chunk = position_and_chunk.second; - auto preset_allocations_iter = preset_allocations.find(buffer.color()); + auto preset_allocations_iter = preset_allocations.find(value.color()); CHECK(preset_allocations_iter != preset_allocations.end()) - << "No preset buffer allocation for color " << buffer.color() + << "No preset value allocation for color " << value.color() << " found."; - preset_allocations_iter->second->AddAssignment(buffer.GetUniqueValue(), - chunk.offset, chunk.size); - // Ensure that there is at most one preset allocation for each buffer. 
- CHECK_EQ(assigned_buffers->count(&buffer), 0); - assigned_buffers->emplace(&buffer); + preset_allocations_iter->second->AddAssignment(value, chunk.offset, + chunk.size); + + const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value); + assigned_buffers->insert(&buffer); } // Upon consumption of the preset assignments, delete it so that if this diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 56e0630eedf..2a02d3776ce 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -566,10 +566,10 @@ class BufferAssigner { static Colorer DefaultColorer() { return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { for (HloValue* value : alias_analysis->dataflow_analysis().values()) { - HloInstruction* defining_instruction = value->defining_instruction(); - if (defining_instruction->shape().has_layout()) { + const HloPosition& defining_position = value->defining_position(); + if (defining_position.shape().has_layout()) { value->set_color(BufferValue::Color( - defining_instruction->shape().layout().memory_space())); + defining_position.shape().layout().memory_space())); } else { value->set_color(BufferValue::Color(0)); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8ea38aa5a1e..e54ad852d44 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -769,6 +769,94 @@ TEST_F(BufferAssignmentTest, PresetAssignments) { GetAssignedOutputAllocation(*buffers, sub); } +TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) { + // Tests preset assignments when there is no 1-to-1 correspondence between + // HloValue and HloBuffer (i.e., a while loop). 
+ auto module = CreateNewVerifiedModule(); + Shape f32vec10_color1 = + ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, 0, 1); + Shape t_s32_f32v10_color1 = + ShapeUtil::MakeTupleShape({s32_, f32vec10_color1}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(s32_, cond_param, 0)); + HloInstruction* cond_limit = cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50))); + cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_limit, ComparisonDirection::kLt)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param")); + HloInstruction* body_iter = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(s32_, body_param, 0)); + HloInstruction* body_data = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1)); + HloInstruction* body_data_increment = body_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}))); + HloInstruction* body_data_next = + body_builder.AddInstruction(HloInstruction::CreateBinary( + f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment)); + HloInstruction* body_iter_increment = body_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1))); + HloInstruction* body_iter_next = + body_builder.AddInstruction(HloInstruction::CreateBinary( + s32_, HloOpcode::kAdd, body_iter, body_iter_increment)); + 
body_builder.AddInstruction( + HloInstruction::CreateTuple({body_iter_next, body_data_next})); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, s32_, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec10_, "param_data")); + HloInstruction* negate = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, negate})); + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + t_s32_f32v10_color1, cond_computation, body_computation, tuple)); + HloInstruction* while_data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1)); + builder.AddInstruction(HloInstruction::CreateBinary( + f32vec10_, HloOpcode::kAdd, while_data, data)); + module->AddEntryComputation(builder.Build()); + + // Give the while data and all of its aliases the same preset chunk. + auto preset_assignments = absl::make_unique<PresetAssignments>(); + preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40}); + preset_assignments->add_chunk({while_op, {1}}, {/*offset=*/100, /*size=*/40}); + preset_assignments->add_chunk({cond_param, {1}}, + {/*offset=*/100, /*size=*/40}); + preset_assignments->add_chunk({body_param, {1}}, + {/*offset=*/100, /*size=*/40}); + preset_assignments->add_chunk({body_data_next, {}}, + {/*offset=*/100, /*size=*/40}); + preset_assignments->add_size(/*memory_space=*/1, /*size=*/140); + + auto buffers = RunBufferAssignmentWithPresetAssignments( + module.get(), std::move(preset_assignments)); + + // All assigned buffers are aliased so they should have the same offset and + // size. 
+ const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate); + EXPECT_EQ(data_buffer.assigned_buffers().size(), 5); + for (const auto& value_and_offsetsize : data_buffer.assigned_buffers()) { + EXPECT_EQ(value_and_offsetsize.second.offset, 100); + EXPECT_EQ(value_and_offsetsize.second.size, 40); + EXPECT_EQ(value_and_offsetsize.first->color(), LogicalBuffer::Color(1)); + } +} + TEST_F(BufferAssignmentTest, MultipleUsersForNode) { // This is similar to the Basic test, with the difference that (sub) is // another user of (mul)'s result, so (mul)'s buffer cannot be reused for