[XLA] Use alias analysis to find all aliased required assmts in mem space assmt.
We were previously using dataflow analysis and that can be incorrect since it didn't know about aliased buffers. Due to that, we might end up assigning module inputs and outputs to the alternate memory space even though that is not allowed. PiperOrigin-RevId: 282050905 Change-Id: I8a257e15a00f2fb1a0155ac9eaa98a582a05cadc
This commit is contained in:
parent
5b38e0e7b2
commit
b6d720fd93
@ -370,9 +370,7 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
|
||||
// adding a required assignment.
|
||||
// TODO(berkin): If these values are already marked alternate memory, use
|
||||
// those instead.
|
||||
const HloDataflowAnalysis& dataflow_analysis =
|
||||
alias_analysis_.dataflow_analysis();
|
||||
const HloModule& module = dataflow_analysis.module();
|
||||
const HloModule& module = alias_analysis_.dataflow_analysis().module();
|
||||
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
||||
HloComputation* entry_computation = module.entry_computation();
|
||||
for (HloInstruction* parameter_instruction :
|
||||
@ -382,15 +380,16 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
|
||||
ShapeUtil::ForEachSubshape(
|
||||
parameter_instruction->shape(),
|
||||
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||
for (const HloValue* value :
|
||||
dataflow_analysis.GetValueSet(parameter_instruction, index)
|
||||
.values()) {
|
||||
VLOG(3) << "Adding required assignment for parameter value = "
|
||||
<< value->ToShortString()
|
||||
<< " time = " << parameter_instruction_time;
|
||||
required_assignments_[value].push_back(
|
||||
{/*memory_space=*/MemorySpace::kDefault,
|
||||
/*time=*/parameter_instruction_time});
|
||||
for (const HloBuffer* buffer :
|
||||
alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
|
||||
for (const HloValue* value : buffer->values()) {
|
||||
VLOG(3) << "Adding required assignment for parameter value = "
|
||||
<< value->ToShortString()
|
||||
<< " time = " << parameter_instruction_time;
|
||||
required_assignments_[value].push_back(
|
||||
{/*memory_space=*/MemorySpace::kDefault,
|
||||
/*time=*/parameter_instruction_time});
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -399,14 +398,16 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
|
||||
ShapeUtil::ForEachSubshape(
|
||||
root_instruction->shape(),
|
||||
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||
for (const HloValue* value :
|
||||
dataflow_analysis.GetValueSet(root_instruction, index).values()) {
|
||||
VLOG(3) << "Adding required assignment for output value = "
|
||||
<< value->ToShortString()
|
||||
<< " time = " << root_instruction_time;
|
||||
required_assignments_[value].push_back(
|
||||
{/*memory_space=*/MemorySpace::kDefault,
|
||||
/*time=*/root_instruction_time});
|
||||
for (const HloBuffer* buffer :
|
||||
alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
|
||||
for (const HloValue* value : buffer->values()) {
|
||||
VLOG(3) << "Adding required assignment for output value = "
|
||||
<< value->ToShortString()
|
||||
<< " time = " << root_instruction_time;
|
||||
required_assignments_[value].push_back(
|
||||
{/*memory_space=*/MemorySpace::kDefault,
|
||||
/*time=*/root_instruction_time});
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -1927,6 +1927,111 @@ TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
|
||||
EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
|
||||
Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
|
||||
Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
|
||||
Shape t_s32_f32v1 = ShapeUtil::MakeTupleShape({s32, f32v1});
|
||||
auto module = CreateNewVerifiedModule("SimpleWhile");
|
||||
HloSchedule schedule(module.get());
|
||||
|
||||
// A simple compare-to-limit (x < 4) computation for a While.
|
||||
//
|
||||
// condition:
|
||||
// const4[s32] -----------------------------------\
|
||||
// \
|
||||
// param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
|
||||
//
|
||||
HloComputation* cond_computation;
|
||||
{
|
||||
auto builder = HloComputation::Builder("WhileCond");
|
||||
auto const4 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
|
||||
auto param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
|
||||
auto index = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
|
||||
auto compare = builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
|
||||
const4, ComparisonDirection::kLt));
|
||||
cond_computation = module->AddEmbeddedComputation(builder.Build());
|
||||
schedule.set_sequence(cond_computation, {const4, param, index, compare});
|
||||
}
|
||||
|
||||
// Builds a simple body computation for a While.
|
||||
//
|
||||
// body:
|
||||
// constv[f32[1]] --------------------------------------\
|
||||
// \
|
||||
// /--- get-tuple-elementv[1] --- addv ---\
|
||||
// param[(s32,f32[1])] ---| tuple
|
||||
// \--- get-tuple-elementc[0] --- addc ---/
|
||||
// /
|
||||
// const1[s32] -----------------------------------------/
|
||||
//
|
||||
HloComputation* body_computation;
|
||||
{
|
||||
auto builder = HloComputation::Builder("WhileBody");
|
||||
auto const1 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
|
||||
auto constv = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.1f})));
|
||||
auto param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
|
||||
auto indexc = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
|
||||
auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
indexc->shape(), HloOpcode::kAdd, indexc, const1));
|
||||
auto indexv = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
|
||||
auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
constv->shape(), HloOpcode::kAdd, indexv, constv));
|
||||
auto tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
|
||||
body_computation = module->AddEmbeddedComputation(builder.Build());
|
||||
schedule.set_sequence(body_computation, {const1, constv, param, indexc,
|
||||
addc, indexv, addv, tuple});
|
||||
}
|
||||
|
||||
// This tests a simple while loop where the parameters are aliased with the
|
||||
// output buffers.
|
||||
auto builder = HloComputation::Builder("SimpleWhile");
|
||||
auto param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, t_s32_f32v1, "param"));
|
||||
auto gte0 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(s32, param, 0));
|
||||
auto gte1 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(f32v1, param, 1));
|
||||
auto tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
|
||||
auto while0 = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
t_s32_f32v1, cond_computation, body_computation, tuple));
|
||||
|
||||
HloComputation* computation = module->AddEntryComputation(builder.Build());
|
||||
schedule.set_sequence(computation, {param, gte0, gte1, tuple, while0});
|
||||
TF_CHECK_OK(module->set_schedule(schedule));
|
||||
|
||||
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
|
||||
/*max_prefetch_interval=*/50);
|
||||
|
||||
// Ensure all parameters and while are placed in default memory.
|
||||
Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
|
||||
F32, {4, 6},
|
||||
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
|
||||
kDefaultMemorySpace);
|
||||
Shape s32_in_default_mem = ShapeUtil::MakeShapeWithLayout(
|
||||
xla::S32, {},
|
||||
/*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
|
||||
kDefaultMemorySpace);
|
||||
Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithLayout(
|
||||
F32, {1},
|
||||
/*minor_to_major=*/{0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
|
||||
kDefaultMemorySpace);
|
||||
Shape t_s32_f32v1_in_default_mem =
|
||||
ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem});
|
||||
EXPECT_THAT(param, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
|
||||
EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
|
||||
MemorySpaceAssignmentTest,
|
||||
::testing::Values(false, true));
|
||||
|
Loading…
Reference in New Issue
Block a user