[XLA] More robust handling for input/output aliasing in mem space assignment.

In this CL, we replace HloBuffer-based allocation decisions with HloValue-based
ones so that we don't unnecessarily couple allocation decisions for values that
have input/output aliasing relationships. We introduce required memory
assignments for inputs and outputs, which prohibit prefetching and can force
evictions to ensure that inputs and outputs are placed in the default memory
space. In a future CL, I will also add support for inputs/outputs that are
pre-set to be in the alternate (fast) memory space.

PiperOrigin-RevId: 273337348
Berkin Ilbeyi, 2019-10-07 11:28:15 -07:00 (committed by TensorFlower Gardener)
parent bb3ac73ba2
commit 4081350d8e
3 changed files with 168 additions and 14 deletions
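
To make the new mechanism concrete before diving into the diff, here is a minimal, self-contained sketch of the gating logic that the required assignments drive inside FindAllocation. The names RequiredAssignment, Decision, and Decide below are illustrative stand-ins rather than the actual XLA types; the real pass keys its maps by const HloValue* and stores RequiredMemoryAssignment entries in an absl::flat_hash_map.

// Minimal sketch, not XLA code: a default-memory requirement at a value's
// definition time still allows a later prefetch, while a requirement at the
// use time rules prefetching out entirely.
#include <cstdint>
#include <iostream>
#include <vector>

enum class MemorySpace { kDefault, kAlternate };

struct RequiredAssignment {
  MemorySpace memory_space;
  int64_t time;
};

enum class Decision { kNoCopyAlternateMem, kPrefetch, kStayInDefaultMem };

Decision Decide(const std::vector<RequiredAssignment>& required,
                int64_t start_time, int64_t end_time) {
  bool definition_requires_default = false;
  bool use_requires_default = false;
  for (const RequiredAssignment& r : required) {
    if (r.memory_space != MemorySpace::kDefault) continue;
    if (r.time == start_time) definition_requires_default = true;
    if (r.time == end_time) use_requires_default = true;
  }
  if (!definition_requires_default && !use_requires_default) {
    // Nothing pins the value, so it may live entirely in alternate memory.
    return Decision::kNoCopyAlternateMem;
  }
  if (use_requires_default) {
    // Prefetching would leave the use reading from alternate memory.
    return Decision::kStayInDefaultMem;
  }
  // Pinned only at the definition (e.g., a parameter): a copy is still allowed.
  return Decision::kPrefetch;
}

int main() {
  // An aliased input/output value is pinned at its definition (t=0) and at its
  // last use (t=12), so the allocation covering that use stays in default mem.
  std::vector<RequiredAssignment> pinned = {{MemorySpace::kDefault, 0},
                                            {MemorySpace::kDefault, 12}};
  std::cout << (Decide(pinned, 0, 12) == Decision::kStayInDefaultMem) << "\n";
  // An earlier use of the same parameter value may still be prefetched.
  std::cout << (Decide(pinned, 0, 5) == Decision::kPrefetch) << "\n";
  // A value with no required assignments can skip the copy entirely.
  std::cout << (Decide({}, 3, 7) == Decision::kNoCopyAlternateMem) << "\n";
}

The three outcomes mirror the paths in FindAllocation below: the no-copy attempt is skipped when either requirement is set, and the early return added before the prefetch logic handles the use-side requirement.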


@@ -53,24 +53,24 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
           << ", min prefetch interval = " << min_prefetch_interval_
           << ", max prefetch interval = " << max_prefetch_interval_;
+  AddInputAndOutputRequiredAssignments();
   for (auto& interval : sorted_buffer_intervals) {
     if (!interval.need_allocation) {
       continue;
     }
     // Skip if we have already allocated for this buffer.
-    const HloBuffer& buffer =
-        alias_analysis_.GetBufferContainingValue(*interval.buffer);
-    if (allocation_map_->contains(&buffer)) {
+    if (allocation_map_->contains(interval.buffer)) {
       continue;
     }
     // If the buffer is a tuple, don't use this algorithm for now. The buffers
-    // that are pointed to by the tuple will still use this algorithm.
-    // TODO(berkin): Because tuples are cheap to place in the alternate memory
-    // (they are just pointers) we don't need to use prefetch/evict logic.
-    if (buffer.values()[0]->shape().IsTuple()) {
-      VLOG(4) << "Keeping buffer " << buffer.ToString()
+    // that are pointed to by the tuple will still use this algorithm. Because
+    // tuples are cheap to place in the alternate memory (they are just
+    // pointers) we don't need to use prefetch/evict logic.
+    if (interval.buffer->shape().IsTuple()) {
+      VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
               << " in default mem because it is a tuple.";
       continue;
     }
@@ -89,9 +89,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
       }
     }
-    MemorySpaceAssignment::AllocationSequence* allocation_sequence =
-        &(*allocation_map_)[&buffer];
     // At this point, none of the colocated buffers contain any phi buffers.
     for (const BufferInterval* colocated_interval : colocated_intervals) {
       if (keep_in_default_memory) {
@@ -99,6 +96,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
       }
       const HloValue* value = colocated_interval->buffer;
       const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
+      MemorySpaceAssignment::AllocationSequence* allocation_sequence =
+          &(*allocation_map_)[value];
       int64 definition_time =
           instruction_schedule.at(value->defining_instruction());
       // Sort the uses by the use time.
@@ -141,7 +140,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
   if (VLOG_IS_ON(3)) {
     for (const auto& alloc_pair : *allocation_map_) {
-      VLOG(3) << "Allocation for " << alloc_pair.first->ToString();
+      VLOG(3) << "Allocation for " << alloc_pair.first->ToShortString();
       for (const auto& alloc : alloc_pair.second) {
         std::string addr_str = ": default";
         if (alloc->memory_space() == MemorySpace::kAlternate) {
@@ -157,6 +156,52 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
   return result_;
 }

+void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
+  // Go through the parameters and outputs and pin them to default memory by
+  // adding a required assignment.
+  // TODO(berkin): If these values are already marked alternate memory, use
+  // those instead.
+  const HloDataflowAnalysis& dataflow_analysis =
+      alias_analysis_.dataflow_analysis();
+  const HloModule& module = dataflow_analysis.module();
+  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
+  HloComputation* entry_computation = module.entry_computation();
+  for (HloInstruction* parameter_instruction :
+       entry_computation->parameter_instructions()) {
+    int64 parameter_instruction_time =
+        instruction_schedule.at(parameter_instruction);
+    ShapeUtil::ForEachSubshape(
+        parameter_instruction->shape(),
+        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
+          for (const HloValue* value :
+               dataflow_analysis.GetValueSet(parameter_instruction, index)
+                   .values()) {
+            VLOG(3) << "Adding required assignment for parameter value = "
+                    << value->ToShortString()
+                    << " time = " << parameter_instruction_time;
+            required_assignments_[value].push_back(
+                {/*memory_space=*/MemorySpace::kDefault,
+                 /*time=*/parameter_instruction_time});
+          }
+        });
+  }
+  HloInstruction* root_instruction = entry_computation->root_instruction();
+  int64 root_instruction_time = instruction_schedule.at(root_instruction);
+  ShapeUtil::ForEachSubshape(
+      root_instruction->shape(),
+      [&](const Shape& /*subshape*/, const ShapeIndex& index) {
+        for (const HloValue* value :
+             dataflow_analysis.GetValueSet(root_instruction, index).values()) {
+          VLOG(3) << "Adding required assignment for output value = "
+                  << value->ToShortString()
+                  << " time = " << root_instruction_time;
+          required_assignments_[value].push_back(
+              {/*memory_space=*/MemorySpace::kDefault,
+               /*time=*/root_instruction_time});
+        }
+      });
+}
+
 void AlternateMemoryBestFitHeap::CommitPendingChunks() {
   for (auto interval_and_chunk : pending_chunks_) {
     VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
@@ -214,8 +259,37 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
                      : "");
   CHECK_LE(start_time, end_time);

+  // There could be a requirement to pin this buffer to default memory either at
+  // the definition site (e.g., parameters) or at the use site (e.g., outputs).
+  // If there is a definition requirement, then we're allowed to prefetch, but
+  // if it's a use requirement, we cannot prefetch the buffer. If the use
+  // expects the buffer to be in default memory, we cannot prefetch it because
+  // if we did, it would be in alternate memory instead.
+  bool definition_requires_buffer_in_default_mem = false;
+  bool use_requires_buffer_in_default_mem = false;
+  auto required_assignment_it = required_assignments_.find(buffer);
+  if (required_assignment_it != required_assignments_.end()) {
+    for (const RequiredMemoryAssignment& required_assignment :
+         required_assignment_it->second) {
+      VLOG(3) << "Required assignment at time = " << required_assignment.time;
+      // TODO(berkin): Handle memory requirements for alternate memory space.
+      if (required_assignment.memory_space == MemorySpace::kDefault) {
+        if (required_assignment.time == start_time) {
+          definition_requires_buffer_in_default_mem = true;
+          VLOG(3) << "Definition requires buffer in default memory.";
+        }
+        if (required_assignment.time == end_time) {
+          use_requires_buffer_in_default_mem = true;
+          VLOG(3) << "Use requires buffer in default memory.";
+        }
+      }
+    }
+  }
+
   // First try keeping the allocation entirely in the alternate memory.
-  if (TryAllocatingInAlternateMemoryNoCopy(
+  if (!definition_requires_buffer_in_default_mem &&
+      !use_requires_buffer_in_default_mem &&
+      TryAllocatingInAlternateMemoryNoCopy(
           start_time, end_time, last_use_time, defining_position, use,
           alternate_mem_interval, non_bitcast_operand, allocations)) {
     return true;
@@ -300,6 +374,15 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
         kDummyChunk, start_time, end_time));
   }

+  // If the use requires the buffer to be in default memory, don't try to
+  // prefetch.
+  if (use_requires_buffer_in_default_mem) {
+    VLOG(4)
+        << "Not trying to prefetch because use requires buffer in default mem.";
+    allocations->back()->AddUse(use);
+    return true;
+  }
+
   // Try partially placing the buffer in the alternate space. The time that is
   // overlapped will be used to asynchronously copy the buffer from the
   // default memory to the alternate memory.


@@ -203,7 +203,7 @@ class MemorySpaceAssignment {
   using AllocationSequence = std::list<std::unique_ptr<Allocation>>;
   using AllocationMap =
-      absl::flat_hash_map<const HloBuffer*, AllocationSequence>;
+      absl::flat_hash_map<const HloValue*, AllocationSequence>;

   // Runs the MemorySpaceAssignment pass. alternate_memory_space is the
   // architecture-specific integer value that describes the alternate memory.
@@ -272,6 +272,15 @@ class MemorySpaceAssignment {
   std::vector<HloPosition> pending_positions_in_alternate_mem_;
 };

+// This struct contains mandatory memory assignments at a given time. E.g., an
+// input's required memory assignment time would correspond to the definition
+// time of the parameter instruction, and an output's time would correspond to
+// the time of last use.
+struct RequiredMemoryAssignment {
+  MemorySpaceAssignment::MemorySpace memory_space;
+  int64 time;
+};
+
 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
 // maximum size.
 class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
@@ -320,6 +329,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
                       HloInstruction* non_bitcast_operand,
                       MemorySpaceAssignment::AllocationSequence* allocations);

+  // Adds inputs and outputs as required assignments.
+  void AddInputAndOutputRequiredAssignments();
+
   // Given a buffer interval, returns the colocated intervals. Unlike the
   // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
   // returns the colocated intervals sorted by scheduled time.
@@ -373,6 +385,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   int64 max_outstanding_async_copies_;
   std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
   std::vector<std::pair<int64, int64>> pending_async_copies_;
+  // This map contains required memory assignments for HloValues (e.g., inputs
+  // and outputs).
+  absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
+      required_assignments_;
 };

 }  // namespace xla


@@ -1168,5 +1168,60 @@ TEST_F(MemorySpaceAssignmentTest, TupleInput) {
   AssignMemorySpace(module.get());
 }

+TEST_F(MemorySpaceAssignmentTest, InputOutputAlias) {
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
+  HloInstruction* p = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p"));
+  HloInstruction* p0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, p, 0));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* p1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, p, 1));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
+  HloInstruction* negate7 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add));
+  HloInstruction* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({p0, add}));
+  auto module = CreateNewVerifiedModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(
+      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
+                    negate6, p1, add, negate7, tuple});
+  TF_CHECK_OK(module->set_schedule(schedule));
+  // Make input {0} alias with output {0} and input {1} alias with output {1}.
+  TF_CHECK_OK(module->input_output_alias_config().SetUpAlias(
+      {0}, 0, {0}, HloInputOutputAliasConfig::AliasKind::kSystemAlias));
+  TF_CHECK_OK(module->input_output_alias_config().SetUpAlias(
+      {1}, 0, {1}, HloInputOutputAliasConfig::AliasKind::kSystemAlias));
+  AssignMemorySpace(module.get());
+  // Make sure the input is in the default memory space.
+  EXPECT_EQ(p->shape().tuple_shapes(0).layout().memory_space(),
+            kDefaultMemorySpace);
+  EXPECT_EQ(p->shape().tuple_shapes(1).layout().memory_space(),
+            kDefaultMemorySpace);
+}
+
 }  // namespace
 }  // namespace xla