[XLA] Use dataflow analysis instead of alias analysis for setting preset assignments.
We were previously using alias analysis to find all aliased HloValues for a given HloBuffer for the HloPosition, which led to allocating the same buffer multiple times. Instead, use dataflow analysis to get only the unique HloValue. This is because there is a one-to-one correspondence between HloPosition and HloValue, but multiple different HloPositions can map to the same HloBuffer in case we have aliasing due to while loops. PiperOrigin-RevId: 280318809 Change-Id: I9c674911de3cb8582c078d7accedf627507e12ca
This commit is contained in:
parent
12fda6ebfe
commit
f95bd6ec17
@ -1323,22 +1323,24 @@ Status BufferAssigner::AssignPresetBuffers(
|
||||
}
|
||||
|
||||
const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
|
||||
const HloDataflowAnalysis& dataflow_analysis =
|
||||
alias_analysis.dataflow_analysis();
|
||||
|
||||
for (auto& position_and_chunk : preset_assignments_->chunks()) {
|
||||
const HloPosition& position = position_and_chunk.first;
|
||||
const HloBuffer& buffer =
|
||||
alias_analysis.GetUniqueBufferAt(position.instruction, position.index);
|
||||
VLOG(3) << "Preset allocation for buffer: " << buffer;
|
||||
const HloValue& value = dataflow_analysis.GetUniqueValueAt(
|
||||
position.instruction, position.index);
|
||||
VLOG(3) << "Preset allocation for value: " << value.ToShortString();
|
||||
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
|
||||
auto preset_allocations_iter = preset_allocations.find(buffer.color());
|
||||
auto preset_allocations_iter = preset_allocations.find(value.color());
|
||||
CHECK(preset_allocations_iter != preset_allocations.end())
|
||||
<< "No preset buffer allocation for color " << buffer.color()
|
||||
<< "No preset value allocation for color " << value.color()
|
||||
<< " found.";
|
||||
preset_allocations_iter->second->AddAssignment(buffer.GetUniqueValue(),
|
||||
chunk.offset, chunk.size);
|
||||
// Ensure that there is at most one preset allocation for each buffer.
|
||||
CHECK_EQ(assigned_buffers->count(&buffer), 0);
|
||||
assigned_buffers->emplace(&buffer);
|
||||
preset_allocations_iter->second->AddAssignment(value, chunk.offset,
|
||||
chunk.size);
|
||||
|
||||
const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value);
|
||||
assigned_buffers->insert(&buffer);
|
||||
}
|
||||
|
||||
// Upon consumption of the preset assignments, delete it so that if this
|
||||
|
@ -566,10 +566,10 @@ class BufferAssigner {
|
||||
static Colorer DefaultColorer() {
|
||||
return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
|
||||
for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
|
||||
HloInstruction* defining_instruction = value->defining_instruction();
|
||||
if (defining_instruction->shape().has_layout()) {
|
||||
const HloPosition& defining_position = value->defining_position();
|
||||
if (defining_position.shape().has_layout()) {
|
||||
value->set_color(BufferValue::Color(
|
||||
defining_instruction->shape().layout().memory_space()));
|
||||
defining_position.shape().layout().memory_space()));
|
||||
} else {
|
||||
value->set_color(BufferValue::Color(0));
|
||||
}
|
||||
|
@ -769,6 +769,94 @@ TEST_F(BufferAssignmentTest, PresetAssignments) {
|
||||
GetAssignedOutputAllocation(*buffers, sub);
|
||||
}
|
||||
|
||||
// Runs buffer assignment with preset assignments on a module that contains a
// while loop, then checks the allocation that holds the loop-carried data.
TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
|
||||
// Tests preset assignments when there is no 1-to-1 correspondence between
|
||||
// HloValue and HloBuffer (i.e., a while loop).
|
||||
auto module = CreateNewVerifiedModule();
|
||||
// f32[10] shape with an explicit layout. NOTE(review): the trailing `1` is
// presumably the layout's memory space (the assertions at the end expect
// color 1) -- confirm against ShapeUtil::MakeShapeWithLayout.
Shape f32vec10_color1 =
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, 0, 1);
|
||||
// Loop-carried state: (s32 iteration counter, f32[10] data).
Shape t_s32_f32v10_color1 =
|
||||
ShapeUtil::MakeTupleShape({s32_, f32vec10_color1});
|
||||
|
||||
// Loop condition: tuple element 0 < 50.
auto cond_builder = HloComputation::Builder("WhileCond");
|
||||
HloInstruction* cond_param = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param"));
|
||||
HloInstruction* cond_iter = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(s32_, cond_param, 0));
|
||||
HloInstruction* cond_limit = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
|
||||
cond_limit, ComparisonDirection::kLt));
|
||||
HloComputation* cond_computation =
|
||||
module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
// Loop body: adds a constant vector to the data element and 1 to the
// counter, then re-packs both into the result tuple.
auto body_builder = HloComputation::Builder("WhileBody");
|
||||
HloInstruction* body_param = body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param"));
|
||||
HloInstruction* body_iter = body_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(s32_, body_param, 0));
|
||||
HloInstruction* body_data = body_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1));
|
||||
HloInstruction* body_data_increment = body_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
|
||||
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
|
||||
HloInstruction* body_data_next =
|
||||
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment));
|
||||
HloInstruction* body_iter_increment = body_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
|
||||
HloInstruction* body_iter_next =
|
||||
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
s32_, HloOpcode::kAdd, body_iter, body_iter_increment));
|
||||
body_builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({body_iter_next, body_data_next}));
|
||||
HloComputation* body_computation =
|
||||
module->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
// Entry computation: negate(param_data) is the initial loop data; after the
// while, tuple element 1 of its result is added to param_data.
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* iter = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, s32_, "param_iter"));
|
||||
HloInstruction* data = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
|
||||
HloInstruction* negate = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data));
|
||||
HloInstruction* tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
|
||||
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
t_s32_f32v10_color1, cond_computation, body_computation, tuple));
|
||||
HloInstruction* while_data = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
f32vec10_, HloOpcode::kAdd, while_data, data));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Give each of the five positions that carry the while data (the initial
// value, the while output, both computation parameters and the body result)
// the same preset chunk: offset 100, size 40.
auto preset_assignments = absl::make_unique<PresetAssignments>();
|
||||
preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
|
||||
preset_assignments->add_chunk({while_op, {1}}, {/*offset=*/100, /*size=*/40});
|
||||
preset_assignments->add_chunk({cond_param, {1}},
|
||||
{/*offset=*/100, /*size=*/40});
|
||||
preset_assignments->add_chunk({body_param, {1}},
|
||||
{/*offset=*/100, /*size=*/40});
|
||||
preset_assignments->add_chunk({body_data_next, {}},
|
||||
{/*offset=*/100, /*size=*/40});
|
||||
// 140 = chunk offset (100) + chunk size (40), recorded for memory space 1.
preset_assignments->add_size(/*memory_space=*/1, /*size=*/140);
|
||||
|
||||
auto buffers = RunBufferAssignmentWithPresetAssignments(
|
||||
module.get(), std::move(preset_assignments));
|
||||
|
||||
// All assigned buffers are aliased so they should have the same offset and
|
||||
// size.
|
||||
const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate);
|
||||
// One entry per preset chunk added above.
EXPECT_EQ(data_buffer.assigned_buffers().size(), 5);
|
||||
for (const auto& value_and_offsetsize : data_buffer.assigned_buffers()) {
|
||||
EXPECT_EQ(value_and_offsetsize.second.offset, 100);
|
||||
EXPECT_EQ(value_and_offsetsize.second.size, 40);
|
||||
EXPECT_EQ(value_and_offsetsize.first->color(), LogicalBuffer::Color(1));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
|
||||
// This is similar to the Basic test, with the difference that (sub) is
|
||||
// another user of (mul)'s result, so (mul)'s buffer cannot be reused for
|
||||
|
Loading…
Reference in New Issue
Block a user