[XLA] Use dataflow analysis instead of alias analysis for setting preset assmts.

We were previously using alias analysis to find all aliased HloValues for a
given HloBuffer for the HloPosition, which led to allocating the same buffer
multiple times. Instead, use dataflow analysis to get only the unique HloValue.
This is because there is one-to-one correspondance between HloPosition and
HloValue, but multiple different HloPositions can map to the same HloBuffer in
case we have aliasing due to while loops.

PiperOrigin-RevId: 280318809
Change-Id: I9c674911de3cb8582c078d7accedf627507e12ca
This commit is contained in:
Berkin Ilbeyi 2019-11-13 17:35:30 -08:00 committed by TensorFlower Gardener
parent 12fda6ebfe
commit f95bd6ec17
3 changed files with 103 additions and 13 deletions

View File

@ -1323,22 +1323,24 @@ Status BufferAssigner::AssignPresetBuffers(
}
const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
const HloDataflowAnalysis& dataflow_analysis =
alias_analysis.dataflow_analysis();
for (auto& position_and_chunk : preset_assignments_->chunks()) {
const HloPosition& position = position_and_chunk.first;
const HloBuffer& buffer =
alias_analysis.GetUniqueBufferAt(position.instruction, position.index);
VLOG(3) << "Preset allocation for buffer: " << buffer;
const HloValue& value = dataflow_analysis.GetUniqueValueAt(
position.instruction, position.index);
VLOG(3) << "Preset allocation for value: " << value.ToShortString();
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
auto preset_allocations_iter = preset_allocations.find(buffer.color());
auto preset_allocations_iter = preset_allocations.find(value.color());
CHECK(preset_allocations_iter != preset_allocations.end())
<< "No preset buffer allocation for color " << buffer.color()
<< "No preset value allocation for color " << value.color()
<< " found.";
preset_allocations_iter->second->AddAssignment(buffer.GetUniqueValue(),
chunk.offset, chunk.size);
// Ensure that there is at most one preset allocation for each buffer.
CHECK_EQ(assigned_buffers->count(&buffer), 0);
assigned_buffers->emplace(&buffer);
preset_allocations_iter->second->AddAssignment(value, chunk.offset,
chunk.size);
const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value);
assigned_buffers->insert(&buffer);
}
// Upon consumption of the preset assignments, delete it so that if this

View File

@ -566,10 +566,10 @@ class BufferAssigner {
static Colorer DefaultColorer() {
return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
HloInstruction* defining_instruction = value->defining_instruction();
if (defining_instruction->shape().has_layout()) {
const HloPosition& defining_position = value->defining_position();
if (defining_position.shape().has_layout()) {
value->set_color(BufferValue::Color(
defining_instruction->shape().layout().memory_space()));
defining_position.shape().layout().memory_space()));
} else {
value->set_color(BufferValue::Color(0));
}

View File

@ -769,6 +769,94 @@ TEST_F(BufferAssignmentTest, PresetAssignments) {
GetAssignedOutputAllocation(*buffers, sub);
}
TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
// Tests preset assignments when there is no 1-to-1 corrspondance between
// HloValue and HloBuffer (i.e., a while loop).
auto module = CreateNewVerifiedModule();
Shape f32vec10_color1 =
ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, 0, 1);
Shape t_s32_f32v10_color1 =
ShapeUtil::MakeTupleShape({s32_, f32vec10_color1});
auto cond_builder = HloComputation::Builder("WhileCond");
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param"));
HloInstruction* cond_iter = cond_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(s32_, cond_param, 0));
HloInstruction* cond_limit = cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
cond_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
cond_limit, ComparisonDirection::kLt));
HloComputation* cond_computation =
module->AddEmbeddedComputation(cond_builder.Build());
auto body_builder = HloComputation::Builder("WhileBody");
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param"));
HloInstruction* body_iter = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(s32_, body_param, 0));
HloInstruction* body_data = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1));
HloInstruction* body_data_increment = body_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
HloInstruction* body_data_next =
body_builder.AddInstruction(HloInstruction::CreateBinary(
f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment));
HloInstruction* body_iter_increment = body_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
HloInstruction* body_iter_next =
body_builder.AddInstruction(HloInstruction::CreateBinary(
s32_, HloOpcode::kAdd, body_iter, body_iter_increment));
body_builder.AddInstruction(
HloInstruction::CreateTuple({body_iter_next, body_data_next}));
HloComputation* body_computation =
module->AddEmbeddedComputation(body_builder.Build());
auto builder = HloComputation::Builder(TestName());
HloInstruction* iter = builder.AddInstruction(
HloInstruction::CreateParameter(0, s32_, "param_iter"));
HloInstruction* data = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
HloInstruction* negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
t_s32_f32v10_color1, cond_computation, body_computation, tuple));
HloInstruction* while_data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1));
builder.AddInstruction(HloInstruction::CreateBinary(
f32vec10_, HloOpcode::kAdd, while_data, data));
module->AddEntryComputation(builder.Build());
// Set only one preset assignment for while data and its aliases.
auto preset_assignments = absl::make_unique<PresetAssignments>();
preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
preset_assignments->add_chunk({while_op, {1}}, {/*offset=*/100, /*size=*/40});
preset_assignments->add_chunk({cond_param, {1}},
{/*offset=*/100, /*size=*/40});
preset_assignments->add_chunk({body_param, {1}},
{/*offset=*/100, /*size=*/40});
preset_assignments->add_chunk({body_data_next, {}},
{/*offset=*/100, /*size=*/40});
preset_assignments->add_size(/*memory_space=*/1, /*size=*/140);
auto buffers = RunBufferAssignmentWithPresetAssignments(
module.get(), std::move(preset_assignments));
// All assigned buffers are aliased so they should have the same offset and
// size.
const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate);
EXPECT_EQ(data_buffer.assigned_buffers().size(), 5);
for (const auto& value_and_offsetsize : data_buffer.assigned_buffers()) {
EXPECT_EQ(value_and_offsetsize.second.offset, 100);
EXPECT_EQ(value_and_offsetsize.second.size, 40);
EXPECT_EQ(value_and_offsetsize.first->color(), LogicalBuffer::Color(1));
}
}
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
// This is similar to the Basic test, with the difference that (sub) is
// another user of (mul)'s result, so (mul)'s buffer cannot be reused for