[XLA] Use alias analysis to find all aliased required assignments in memory space assignment.

We were previously using dataflow analysis, which can be incorrect since it
does not know about aliased buffers. Because of that, we might end up
assigning module inputs and outputs to the alternate memory space even though
that is not allowed.

PiperOrigin-RevId: 282050905
Change-Id: I8a257e15a00f2fb1a0155ac9eaa98a582a05cadc
This commit is contained in:
Berkin Ilbeyi 2019-11-22 15:09:37 -08:00 committed by TensorFlower Gardener
parent 5b38e0e7b2
commit b6d720fd93
2 changed files with 126 additions and 20 deletions

View File

@ -370,9 +370,7 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
// adding a required assignment.
// TODO(berkin): If these values are already marked alternate memory, use
// those instead.
const HloDataflowAnalysis& dataflow_analysis =
alias_analysis_.dataflow_analysis();
const HloModule& module = dataflow_analysis.module();
const HloModule& module = alias_analysis_.dataflow_analysis().module();
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
HloComputation* entry_computation = module.entry_computation();
for (HloInstruction* parameter_instruction :
@ -382,15 +380,16 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
ShapeUtil::ForEachSubshape(
parameter_instruction->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
for (const HloValue* value :
dataflow_analysis.GetValueSet(parameter_instruction, index)
.values()) {
VLOG(3) << "Adding required assignment for parameter value = "
<< value->ToShortString()
<< " time = " << parameter_instruction_time;
required_assignments_[value].push_back(
{/*memory_space=*/MemorySpace::kDefault,
/*time=*/parameter_instruction_time});
for (const HloBuffer* buffer :
alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
for (const HloValue* value : buffer->values()) {
VLOG(3) << "Adding required assignment for parameter value = "
<< value->ToShortString()
<< " time = " << parameter_instruction_time;
required_assignments_[value].push_back(
{/*memory_space=*/MemorySpace::kDefault,
/*time=*/parameter_instruction_time});
}
}
});
}
@ -399,14 +398,16 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
ShapeUtil::ForEachSubshape(
root_instruction->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
for (const HloValue* value :
dataflow_analysis.GetValueSet(root_instruction, index).values()) {
VLOG(3) << "Adding required assignment for output value = "
<< value->ToShortString()
<< " time = " << root_instruction_time;
required_assignments_[value].push_back(
{/*memory_space=*/MemorySpace::kDefault,
/*time=*/root_instruction_time});
for (const HloBuffer* buffer :
alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
for (const HloValue* value : buffer->values()) {
VLOG(3) << "Adding required assignment for output value = "
<< value->ToShortString()
<< " time = " << root_instruction_time;
required_assignments_[value].push_back(
{/*memory_space=*/MemorySpace::kDefault,
/*time=*/root_instruction_time});
}
}
});
}

View File

@ -1927,6 +1927,111 @@ TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
}
// Tests that buffers which alias module inputs/outputs through a while loop
// are kept in the default memory space. The while's operand tuple aliases the
// entry parameter and the while result, so alias analysis must propagate the
// default-memory requirement through all aliased values; a dataflow-only
// analysis would miss this (see AddInputAndOutputRequiredAssignments).
TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
  Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
  Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
  Shape t_s32_f32v1 = ShapeUtil::MakeTupleShape({s32, f32v1});
  auto module = CreateNewVerifiedModule("SimpleWhile");
  HloSchedule schedule(module.get());

  // A simple compare-to-limit (x < 4) computation for a While.
  //
  // condition:
  //   const4[s32] -----------------------------------\
  //                                                   \
  //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
  //
  HloComputation* cond_computation;
  {
    auto builder = HloComputation::Builder("WhileCond");
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
    auto index = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
    auto compare = builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
                                      const4, ComparisonDirection::kLt));
    cond_computation = module->AddEmbeddedComputation(builder.Build());
    schedule.set_sequence(cond_computation, {const4, param, index, compare});
  }

  // Builds a simple body computation for a While.
  //
  // body:
  //   constv[f32[1]] --------------------------------------\
  //                                                         \
  //                  /--- get-tuple-elementv[1] --- addv ---\
  //   param[(s32,f32[1])] ---|                               tuple
  //                  \--- get-tuple-elementc[0] --- addc ---/
  //                                                         /
  //   const1[s32] -----------------------------------------/
  //
  HloComputation* body_computation;
  {
    auto builder = HloComputation::Builder("WhileBody");
    auto const1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
    auto constv = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.1f})));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
    auto indexc = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
    auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
        indexc->shape(), HloOpcode::kAdd, indexc, const1));
    auto indexv = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
    auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
        constv->shape(), HloOpcode::kAdd, indexv, constv));
    auto tuple =
        builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
    body_computation = module->AddEmbeddedComputation(builder.Build());
    schedule.set_sequence(body_computation, {const1, constv, param, indexc,
                                             addc, indexv, addv, tuple});
  }

  // This tests a simple while loop where the parameters are aliased with the
  // output buffers.
  auto builder = HloComputation::Builder("SimpleWhile");
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, t_s32_f32v1, "param"));
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(s32, param, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32v1, param, 1));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
  auto while0 = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v1, cond_computation, body_computation, tuple));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  schedule.set_sequence(computation, {param, gte0, gte1, tuple, while0});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Use a large prefetch interval so the only thing keeping these buffers in
  // default memory is the input/output required-assignment logic under test.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/50);

  // Ensure all parameters and the while are placed in default memory.
  // (The original revision also built an unused F32 {4,6} shape here, left
  // over from the previous test; it has been removed.)
  Shape s32_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      xla::S32, {},
      /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kDefaultMemorySpace);
  Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {1},
      /*minor_to_major=*/{0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kDefaultMemorySpace);
  Shape t_s32_f32v1_in_default_mem =
      ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem});
  EXPECT_THAT(param, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
  EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
}
// Run every MemorySpaceAssignmentTest case with the bool parameter set to
// false and to true. NOTE(review): the parameter's meaning is defined by the
// fixture's GetParam() usage, which is outside this view — presumably it
// toggles an assignment mode; confirm against the fixture definition.
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
                         MemorySpaceAssignmentTest,
                         ::testing::Values(false, true));