[XLA] Perform the check for cross-computation buffers in alternate mem earlier.

Previously, we checked whether a definition and its use are in the same
computation only after the no-copy allocation failed. This could create illegal
references to instructions across the computation boundary. The fix is
admittedly a bit restrictive, but I will investigate approaches to enable better
use of alternate memory for while loops.

PiperOrigin-RevId: 277200085
Change-Id: I4be71188a5da75d71d16353241b5622890a80c0d
This commit is contained in:
Berkin Ilbeyi 2019-10-28 20:32:14 -07:00 committed by TensorFlower Gardener
parent 941bec1bbf
commit c7a4062cc2
2 changed files with 134 additions and 9 deletions

View File

@ -306,15 +306,10 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
}
}
// First try keeping the allocation entirely in the alternate memory.
if (!definition_requires_buffer_in_default_mem &&
!use_requires_buffer_in_default_mem &&
TryAllocatingInAlternateMemoryNoCopy(
start_time, end_time, last_use_time, defining_position, use,
alternate_mem_interval, non_bitcast_operand, allocations)) {
return true;
}
// TODO(berkin): This is curently overly restrictive and will fail using
// alternate memory for any buffer that might leak into a different
// computation (e.g., while body). Enable more usage of alternate memory
// across computations.
if (defining_position.instruction->parent() != use.instruction->parent() ||
(!use.instruction->called_computations().empty() &&
use.instruction->opcode() != HloOpcode::kFusion)) {
@ -324,6 +319,15 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
return false;
}
// First try keeping the allocation entirely in the alternate memory.
if (!definition_requires_buffer_in_default_mem &&
!use_requires_buffer_in_default_mem &&
TryAllocatingInAlternateMemoryNoCopy(
start_time, end_time, last_use_time, defining_position, use,
alternate_mem_interval, non_bitcast_operand, allocations)) {
return true;
}
MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
if (!allocations->empty()) {
prev_allocation = allocations->back().get();

View File

@ -1053,6 +1053,127 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
AssignMemorySpace(module.get(), -1, 5);
}
TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
// This test reproduces the failure in b/143288178. Given a graph like the
// following:
//
// ... = foo(a)
// tuple = tuple(..., a)
// ... = while(tuple) {
// p = param(0)
// a1 = get-tuple-element(p), index=n-1
// ...
// ROOT tuple(..., a1)
// }
//
// If a copy to alternate memory is inserted before foo, and if the size of
// the while body is less than max prefetch interval so that the copy-done is
// kept in the alternate memory, then we end up referring to the copy-done in
// the root instruction of the while loop body. I.e.,
//
// cs = copy-start(a)
// ...
// cd = copy-done(cs)
// ... = foo(cd)
// tuple = tuple(..., cd)
// ... = while(tuple) {
// p = param(0)
// a1 = get-tuple-element(p), index=n-1
// ...
// ROOT tuple(..., cd) <-- Error: cd belongs to outside computation.
// }
//
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
// While-loop state: (data buffer, loop counter, pass-through scalar).
Shape tuple_shape =
ShapeUtil::MakeTupleShape({shape, scalar_shape, scalar_shape});
// While condition: iterate while the counter (tuple index 1) is < 50.
auto cond_builder = HloComputation::Builder("WhileCond");
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
HloInstruction* cond_iter = cond_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
HloInstruction* cond_limit = cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
HloInstruction* cond_lt = cond_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
cond_limit, ComparisonDirection::kLt));
HloComputation* cond_computation =
module->AddEmbeddedComputation(cond_builder.Build());
// While body: increment the counter; pass the data buffer (index 0) and the
// scalar (index 2) through unchanged so their buffers alias across
// iterations — this is what can leak an outside copy-done into the body's
// root (see comment above).
auto body_builder = HloComputation::Builder("WhileBody");
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
HloInstruction* body_iter = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
HloInstruction* body_data = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
HloInstruction* body_iter_increment = body_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
HloInstruction* body_iter_next =
body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
HloInstruction* body_data2 = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 2));
HloInstruction* body_out = body_builder.AddInstruction(
HloInstruction::CreateTuple({body_data, body_iter_next, body_data2}));
HloComputation* body_computation =
module->AddEmbeddedComputation(body_builder.Build());
// Entry computation: a chain of negates keeps `data` live long enough that
// memory space assignment is tempted to prefetch it into alternate memory
// before it flows into the while loop's init tuple.
auto builder = HloComputation::Builder(TestName());
HloInstruction* data = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param_data"));
HloInstruction* iter = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
HloInstruction* data2 = builder.AddInstruction(
HloInstruction::CreateParameter(2, scalar_shape, "param_data2"));
HloInstruction* negate0 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
HloInstruction* negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
HloInstruction* negate2 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
HloInstruction* negate3 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
HloInstruction* negate4 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
HloInstruction* negate5 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
HloInstruction* negate6 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
HloInstruction* negate7 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kSubtract, iter, data2));
HloInstruction* tuple = builder.AddInstruction(
HloInstruction::CreateTuple({negate7, iter, data2}));
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
tuple_shape, cond_computation, body_computation, tuple));
HloInstruction* while_data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, while_op, 0));
HloInstruction* root =
builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
// Explicit schedule: memory space assignment operates on the sequential
// order, so the test pins it rather than relying on a scheduler.
HloSchedule schedule(module.get());
schedule.set_sequence(cond_computation,
{cond_param, cond_iter, cond_limit, cond_lt});
schedule.set_sequence(body_computation,
{body_param, body_iter, body_data, body_iter_increment,
body_iter_next, body_data2, body_out});
schedule.set_sequence(
entry_computation,
{iter, data, data2, negate0, negate1, negate2, negate3, negate4, negate5,
negate6, negate7, sub, tuple, while_op, while_data, root});
TF_CHECK_OK(module->set_schedule(schedule));
// Set a large max prefetch interval so that the buffer can be kept in
// alternate memory. The test passes if assignment (and HLO verification
// inside it) succeeds without creating cross-computation references.
AssignMemorySpace(module.get(), -1, 20);
}
TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
// This situation was encountered in vss, where there is a mismatch in the
// memory space in preset assignments and the output graph.