[XLA] Use dataflow analysis instead of alias analysis for setting preset assignments.
We were previously using alias analysis to find all aliased HloValues for a given HloBuffer for the HloPosition, which led to allocating the same buffer multiple times. Instead, use dataflow analysis to get only the unique HloValue. This is because there is a one-to-one correspondence between HloPosition and HloValue, but multiple different HloPositions can map to the same HloBuffer when there is aliasing due to while loops. PiperOrigin-RevId: 280318809 Change-Id: I9c674911de3cb8582c078d7accedf627507e12ca
This commit is contained in:
parent
12fda6ebfe
commit
f95bd6ec17
@ -1323,22 +1323,24 @@ Status BufferAssigner::AssignPresetBuffers(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
|
const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
|
||||||
|
const HloDataflowAnalysis& dataflow_analysis =
|
||||||
|
alias_analysis.dataflow_analysis();
|
||||||
|
|
||||||
for (auto& position_and_chunk : preset_assignments_->chunks()) {
|
for (auto& position_and_chunk : preset_assignments_->chunks()) {
|
||||||
const HloPosition& position = position_and_chunk.first;
|
const HloPosition& position = position_and_chunk.first;
|
||||||
const HloBuffer& buffer =
|
const HloValue& value = dataflow_analysis.GetUniqueValueAt(
|
||||||
alias_analysis.GetUniqueBufferAt(position.instruction, position.index);
|
position.instruction, position.index);
|
||||||
VLOG(3) << "Preset allocation for buffer: " << buffer;
|
VLOG(3) << "Preset allocation for value: " << value.ToShortString();
|
||||||
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
|
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
|
||||||
auto preset_allocations_iter = preset_allocations.find(buffer.color());
|
auto preset_allocations_iter = preset_allocations.find(value.color());
|
||||||
CHECK(preset_allocations_iter != preset_allocations.end())
|
CHECK(preset_allocations_iter != preset_allocations.end())
|
||||||
<< "No preset buffer allocation for color " << buffer.color()
|
<< "No preset value allocation for color " << value.color()
|
||||||
<< " found.";
|
<< " found.";
|
||||||
preset_allocations_iter->second->AddAssignment(buffer.GetUniqueValue(),
|
preset_allocations_iter->second->AddAssignment(value, chunk.offset,
|
||||||
chunk.offset, chunk.size);
|
chunk.size);
|
||||||
// Ensure that there is at most one preset allocation for each buffer.
|
|
||||||
CHECK_EQ(assigned_buffers->count(&buffer), 0);
|
const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value);
|
||||||
assigned_buffers->emplace(&buffer);
|
assigned_buffers->insert(&buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upon consumption of the preset assignments, delete it so that if this
|
// Upon consumption of the preset assignments, delete it so that if this
|
||||||
|
@ -566,10 +566,10 @@ class BufferAssigner {
|
|||||||
static Colorer DefaultColorer() {
|
static Colorer DefaultColorer() {
|
||||||
return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
|
return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
|
||||||
for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
|
for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
|
||||||
HloInstruction* defining_instruction = value->defining_instruction();
|
const HloPosition& defining_position = value->defining_position();
|
||||||
if (defining_instruction->shape().has_layout()) {
|
if (defining_position.shape().has_layout()) {
|
||||||
value->set_color(BufferValue::Color(
|
value->set_color(BufferValue::Color(
|
||||||
defining_instruction->shape().layout().memory_space()));
|
defining_position.shape().layout().memory_space()));
|
||||||
} else {
|
} else {
|
||||||
value->set_color(BufferValue::Color(0));
|
value->set_color(BufferValue::Color(0));
|
||||||
}
|
}
|
||||||
|
@ -769,6 +769,94 @@ TEST_F(BufferAssignmentTest, PresetAssignments) {
|
|||||||
GetAssignedOutputAllocation(*buffers, sub);
|
GetAssignedOutputAllocation(*buffers, sub);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
|
||||||
|
// Tests preset assignments when there is no 1-to-1 correspondence between
|
||||||
|
// HloValue and HloBuffer (i.e., a while loop).
|
||||||
|
auto module = CreateNewVerifiedModule();
|
||||||
|
Shape f32vec10_color1 =
|
||||||
|
ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, 0, 1);
|
||||||
|
Shape t_s32_f32v10_color1 =
|
||||||
|
ShapeUtil::MakeTupleShape({s32_, f32vec10_color1});
|
||||||
|
|
||||||
|
auto cond_builder = HloComputation::Builder("WhileCond");
|
||||||
|
HloInstruction* cond_param = cond_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param"));
|
||||||
|
HloInstruction* cond_iter = cond_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(s32_, cond_param, 0));
|
||||||
|
HloInstruction* cond_limit = cond_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
|
||||||
|
cond_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
|
||||||
|
cond_limit, ComparisonDirection::kLt));
|
||||||
|
HloComputation* cond_computation =
|
||||||
|
module->AddEmbeddedComputation(cond_builder.Build());
|
||||||
|
|
||||||
|
auto body_builder = HloComputation::Builder("WhileBody");
|
||||||
|
HloInstruction* body_param = body_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param"));
|
||||||
|
HloInstruction* body_iter = body_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(s32_, body_param, 0));
|
||||||
|
HloInstruction* body_data = body_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1));
|
||||||
|
HloInstruction* body_data_increment = body_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
|
||||||
|
HloInstruction* body_data_next =
|
||||||
|
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment));
|
||||||
|
HloInstruction* body_iter_increment = body_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
|
||||||
|
HloInstruction* body_iter_next =
|
||||||
|
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
s32_, HloOpcode::kAdd, body_iter, body_iter_increment));
|
||||||
|
body_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateTuple({body_iter_next, body_data_next}));
|
||||||
|
HloComputation* body_computation =
|
||||||
|
module->AddEmbeddedComputation(body_builder.Build());
|
||||||
|
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
HloInstruction* iter = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, s32_, "param_iter"));
|
||||||
|
HloInstruction* data = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
|
||||||
|
HloInstruction* negate = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data));
|
||||||
|
HloInstruction* tuple =
|
||||||
|
builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
|
||||||
|
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||||
|
t_s32_f32v10_color1, cond_computation, body_computation, tuple));
|
||||||
|
HloInstruction* while_data = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
f32vec10_, HloOpcode::kAdd, while_data, data));
|
||||||
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
// Set only one preset assignment for while data and its aliases.
|
||||||
|
auto preset_assignments = absl::make_unique<PresetAssignments>();
|
||||||
|
preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
|
||||||
|
preset_assignments->add_chunk({while_op, {1}}, {/*offset=*/100, /*size=*/40});
|
||||||
|
preset_assignments->add_chunk({cond_param, {1}},
|
||||||
|
{/*offset=*/100, /*size=*/40});
|
||||||
|
preset_assignments->add_chunk({body_param, {1}},
|
||||||
|
{/*offset=*/100, /*size=*/40});
|
||||||
|
preset_assignments->add_chunk({body_data_next, {}},
|
||||||
|
{/*offset=*/100, /*size=*/40});
|
||||||
|
preset_assignments->add_size(/*memory_space=*/1, /*size=*/140);
|
||||||
|
|
||||||
|
auto buffers = RunBufferAssignmentWithPresetAssignments(
|
||||||
|
module.get(), std::move(preset_assignments));
|
||||||
|
|
||||||
|
// All assigned buffers are aliased so they should have the same offset and
|
||||||
|
// size.
|
||||||
|
const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate);
|
||||||
|
EXPECT_EQ(data_buffer.assigned_buffers().size(), 5);
|
||||||
|
for (const auto& value_and_offsetsize : data_buffer.assigned_buffers()) {
|
||||||
|
EXPECT_EQ(value_and_offsetsize.second.offset, 100);
|
||||||
|
EXPECT_EQ(value_and_offsetsize.second.size, 40);
|
||||||
|
EXPECT_EQ(value_and_offsetsize.first->color(), LogicalBuffer::Color(1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
|
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
|
||||||
// This is similar to the Basic test, with the difference that (sub) is
|
// This is similar to the Basic test, with the difference that (sub) is
|
||||||
// another user of (mul)'s result, so (mul)'s buffer cannot be reused for
|
// another user of (mul)'s result, so (mul)'s buffer cannot be reused for
|
||||||
|
Loading…
Reference in New Issue
Block a user