[XLA] Perform the check for cross-computation buffers in alternate mem earlier.
Previously, we were performing if a definition and its use is in the same computation only after the no-copy allocation failed. This can create illegal references to instructions across the computation boundary. The fix is admittedly a bit restrictive but I will investigate approaches to enable better use of alternate memory for while loops. PiperOrigin-RevId: 277200085 Change-Id: I4be71188a5da75d71d16353241b5622890a80c0d
This commit is contained in:
parent
941bec1bbf
commit
c7a4062cc2
@ -306,15 +306,10 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
|
||||
}
|
||||
}
|
||||
|
||||
// First try keeping the allocation entirely in the alternate memory.
|
||||
if (!definition_requires_buffer_in_default_mem &&
|
||||
!use_requires_buffer_in_default_mem &&
|
||||
TryAllocatingInAlternateMemoryNoCopy(
|
||||
start_time, end_time, last_use_time, defining_position, use,
|
||||
alternate_mem_interval, non_bitcast_operand, allocations)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// TODO(berkin): This is curently overly restrictive and will fail using
|
||||
// alternate memory for any buffer that might leak into a different
|
||||
// computation (e.g., while body). Enable more usage of alternate memory
|
||||
// across computations.
|
||||
if (defining_position.instruction->parent() != use.instruction->parent() ||
|
||||
(!use.instruction->called_computations().empty() &&
|
||||
use.instruction->opcode() != HloOpcode::kFusion)) {
|
||||
@ -324,6 +319,15 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
|
||||
return false;
|
||||
}
|
||||
|
||||
// First try keeping the allocation entirely in the alternate memory.
|
||||
if (!definition_requires_buffer_in_default_mem &&
|
||||
!use_requires_buffer_in_default_mem &&
|
||||
TryAllocatingInAlternateMemoryNoCopy(
|
||||
start_time, end_time, last_use_time, defining_position, use,
|
||||
alternate_mem_interval, non_bitcast_operand, allocations)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
|
||||
if (!allocations->empty()) {
|
||||
prev_allocation = allocations->back().get();
|
||||
|
@ -1053,6 +1053,127 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
|
||||
AssignMemorySpace(module.get(), -1, 5);
|
||||
}
|
||||
|
||||
TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
|
||||
// This test reproduces the failure in b/143288178. Given a graph like the
|
||||
// following:
|
||||
//
|
||||
// ... = foo(a)
|
||||
// tuple = tuple((..., a)
|
||||
// ... = while(tuple) {
|
||||
// p = param(0)
|
||||
// a1 = get-tuple-element(p), index=n-1
|
||||
// ...
|
||||
// ROOT tuple((..., a1))
|
||||
// }
|
||||
//
|
||||
// If a copy to alternate memory is inserted before foo, and if the size of
|
||||
// the while body is less than max prefetch interval so that the copy-done is
|
||||
// kept in the alternate memory, then we end up refering to the copy-done in
|
||||
// the root instruction of the while loop body. I.e.,
|
||||
//
|
||||
// cs = copy-start(a)
|
||||
// ...
|
||||
// cd = copy-done(cs)
|
||||
// ... = foo(cd)
|
||||
// tuple = tuple((..., cd)
|
||||
// ... = while(tuple) {
|
||||
// p = param(0)
|
||||
// a1 = get-tuple-element(p), index=n-1
|
||||
// ...
|
||||
// ROOT tuple((..., cd)) <-- Error: cd belongs to outside computation.
|
||||
// }
|
||||
//
|
||||
auto module = CreateNewVerifiedModule();
|
||||
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
|
||||
Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
|
||||
Shape tuple_shape =
|
||||
ShapeUtil::MakeTupleShape({shape, scalar_shape, scalar_shape});
|
||||
|
||||
auto cond_builder = HloComputation::Builder("WhileCond");
|
||||
HloInstruction* cond_param = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
|
||||
HloInstruction* cond_iter = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
||||
HloInstruction* cond_limit = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
|
||||
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
|
||||
cond_limit, ComparisonDirection::kLt));
|
||||
HloComputation* cond_computation =
|
||||
module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
auto body_builder = HloComputation::Builder("WhileBody");
|
||||
HloInstruction* body_param = body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
|
||||
HloInstruction* body_iter = body_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
|
||||
HloInstruction* body_data = body_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
|
||||
HloInstruction* body_iter_increment = body_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
|
||||
HloInstruction* body_iter_next =
|
||||
body_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
|
||||
HloInstruction* body_data2 = body_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 2));
|
||||
HloInstruction* body_out = body_builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({body_data, body_iter_next, body_data2}));
|
||||
HloComputation* body_computation =
|
||||
module->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* data = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param_data"));
|
||||
HloInstruction* iter = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
|
||||
HloInstruction* data2 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(2, scalar_shape, "param_data2"));
|
||||
HloInstruction* negate0 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
|
||||
HloInstruction* negate1 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
|
||||
HloInstruction* negate2 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
|
||||
HloInstruction* negate3 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
|
||||
HloInstruction* negate4 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
|
||||
HloInstruction* negate5 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
|
||||
HloInstruction* negate6 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
|
||||
HloInstruction* negate7 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
|
||||
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape, HloOpcode::kSubtract, iter, data2));
|
||||
HloInstruction* tuple = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({negate7, iter, data2}));
|
||||
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
tuple_shape, cond_computation, body_computation, tuple));
|
||||
HloInstruction* while_data = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(shape, while_op, 0));
|
||||
HloInstruction* root =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
|
||||
HloComputation* entry_computation =
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
HloSchedule schedule(module.get());
|
||||
schedule.set_sequence(cond_computation,
|
||||
{cond_param, cond_iter, cond_limit, cond_lt});
|
||||
schedule.set_sequence(body_computation,
|
||||
{body_param, body_iter, body_data, body_iter_increment,
|
||||
body_iter_next, body_data2, body_out});
|
||||
schedule.set_sequence(
|
||||
entry_computation,
|
||||
{iter, data, data2, negate0, negate1, negate2, negate3, negate4, negate5,
|
||||
negate6, negate7, sub, tuple, while_op, while_data, root});
|
||||
TF_CHECK_OK(module->set_schedule(schedule));
|
||||
|
||||
// Set a large max prefetch interval so that the buffer can be kept in
|
||||
// alternate memory.
|
||||
AssignMemorySpace(module.get(), -1, 20);
|
||||
}
|
||||
|
||||
TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
|
||||
// This situation was encountered in vss, where there is a mismatch in the
|
||||
// memory space in preset assignments and the output graph.
|
||||
|
Loading…
Reference in New Issue
Block a user