[XLA] More robust handling for input/output aliasing in mem space assignment.
In this CL, we substitute HloBuffer-based allocation decisions with HloValue-based ones so that we don't unnecessarily couple allocation decisions for values that have input/output aliasing relationships. We introduce required memory assignments for inputs and outputs which prohibit prefetching and can force evictions to ensure inputs/outputs are properly assigned in the default memory space. In a future CL, I will also add support for inputs/outputs that are pre-set to be in the alternate (fast) memory space. PiperOrigin-RevId: 273337348
This commit is contained in:
parent
bb3ac73ba2
commit
4081350d8e
@ -53,24 +53,24 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
<< ", min prefetch interval = " << min_prefetch_interval_
|
||||
<< ", max prefetch interval = " << max_prefetch_interval_;
|
||||
|
||||
AddInputAndOutputRequiredAssignments();
|
||||
|
||||
for (auto& interval : sorted_buffer_intervals) {
|
||||
if (!interval.need_allocation) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if we have already allocated for this buffer.
|
||||
const HloBuffer& buffer =
|
||||
alias_analysis_.GetBufferContainingValue(*interval.buffer);
|
||||
if (allocation_map_->contains(&buffer)) {
|
||||
if (allocation_map_->contains(interval.buffer)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the buffer is a tuple, don't use this algorithm for now. The buffers
|
||||
// that are pointed to by the tuple will still use this algorithm.
|
||||
// TODO(berkin): Because tuples are cheap to place in the alternate memory
|
||||
// (they are just pointers) we don't need to use prefetch/evict logic.
|
||||
if (buffer.values()[0]->shape().IsTuple()) {
|
||||
VLOG(4) << "Keeping buffer " << buffer.ToString()
|
||||
// that are pointed to by the tuple will still use this algorithm. Because
|
||||
// tuples are cheap to place in the alternate memory (they are just
|
||||
// pointers) we don't need to use prefetch/evict logic.
|
||||
if (interval.buffer->shape().IsTuple()) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
<< " in default mem because it is a tuple.";
|
||||
continue;
|
||||
}
|
||||
@ -89,9 +89,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
}
|
||||
}
|
||||
|
||||
MemorySpaceAssignment::AllocationSequence* allocation_sequence =
|
||||
&(*allocation_map_)[&buffer];
|
||||
|
||||
// At this point, none of the colocated buffers contain any phi buffers.
|
||||
for (const BufferInterval* colocated_interval : colocated_intervals) {
|
||||
if (keep_in_default_memory) {
|
||||
@ -99,6 +96,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
}
|
||||
const HloValue* value = colocated_interval->buffer;
|
||||
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
||||
MemorySpaceAssignment::AllocationSequence* allocation_sequence =
|
||||
&(*allocation_map_)[value];
|
||||
int64 definition_time =
|
||||
instruction_schedule.at(value->defining_instruction());
|
||||
// Sort the uses by the use time.
|
||||
@ -141,7 +140,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
for (const auto& alloc_pair : *allocation_map_) {
|
||||
VLOG(3) << "Allocation for " << alloc_pair.first->ToString();
|
||||
VLOG(3) << "Allocation for " << alloc_pair.first->ToShortString();
|
||||
for (const auto& alloc : alloc_pair.second) {
|
||||
std::string addr_str = ": default";
|
||||
if (alloc->memory_space() == MemorySpace::kAlternate) {
|
||||
@ -157,6 +156,52 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
return result_;
|
||||
}
|
||||
|
||||
void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
|
||||
// Go through the parameters and outputs and pin them to default memory by
|
||||
// adding a required assignment.
|
||||
// TODO(berkin): If these values are already marked alternate memory, use
|
||||
// those instead.
|
||||
const HloDataflowAnalysis& dataflow_analysis =
|
||||
alias_analysis_.dataflow_analysis();
|
||||
const HloModule& module = dataflow_analysis.module();
|
||||
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
||||
HloComputation* entry_computation = module.entry_computation();
|
||||
for (HloInstruction* parameter_instruction :
|
||||
entry_computation->parameter_instructions()) {
|
||||
int64 parameter_instruction_time =
|
||||
instruction_schedule.at(parameter_instruction);
|
||||
ShapeUtil::ForEachSubshape(
|
||||
parameter_instruction->shape(),
|
||||
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||
for (const HloValue* value :
|
||||
dataflow_analysis.GetValueSet(parameter_instruction, index)
|
||||
.values()) {
|
||||
VLOG(3) << "Adding required assignment for parameter value = "
|
||||
<< value->ToShortString()
|
||||
<< " time = " << parameter_instruction_time;
|
||||
required_assignments_[value].push_back(
|
||||
{/*memory_space=*/MemorySpace::kDefault,
|
||||
/*time=*/parameter_instruction_time});
|
||||
}
|
||||
});
|
||||
}
|
||||
HloInstruction* root_instruction = entry_computation->root_instruction();
|
||||
int64 root_instruction_time = instruction_schedule.at(root_instruction);
|
||||
ShapeUtil::ForEachSubshape(
|
||||
root_instruction->shape(),
|
||||
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||
for (const HloValue* value :
|
||||
dataflow_analysis.GetValueSet(root_instruction, index).values()) {
|
||||
VLOG(3) << "Adding required assignment for output value = "
|
||||
<< value->ToShortString()
|
||||
<< " time = " << root_instruction_time;
|
||||
required_assignments_[value].push_back(
|
||||
{/*memory_space=*/MemorySpace::kDefault,
|
||||
/*time=*/root_instruction_time});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void AlternateMemoryBestFitHeap::CommitPendingChunks() {
|
||||
for (auto interval_and_chunk : pending_chunks_) {
|
||||
VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
|
||||
@ -214,8 +259,37 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
|
||||
: "");
|
||||
CHECK_LE(start_time, end_time);
|
||||
|
||||
// There could be a requirement to pin this buffer to default memory either at
|
||||
// the definition site (e.g., parameters) or at the use site (e.g., outputs).
|
||||
// If there is a definition requirement, then we're allowed to prefetch, but
|
||||
// if it's a use requirement, we cannot prefetch the buffer. If the use
|
||||
// expects the buffer to be in default memory, we cannot prefetch it because
|
||||
// if we did, it would be in alternate memory instead.
|
||||
bool definition_requires_buffer_in_default_mem = false;
|
||||
bool use_requires_buffer_in_default_mem = false;
|
||||
auto required_assignment_it = required_assignments_.find(buffer);
|
||||
if (required_assignment_it != required_assignments_.end()) {
|
||||
for (const RequiredMemoryAssignment& required_assignment :
|
||||
required_assignment_it->second) {
|
||||
VLOG(3) << "Required assignment at time = " << required_assignment.time;
|
||||
// TODO(berkin): Handle memory requirements for alternate memory space.
|
||||
if (required_assignment.memory_space == MemorySpace::kDefault) {
|
||||
if (required_assignment.time == start_time) {
|
||||
definition_requires_buffer_in_default_mem = true;
|
||||
VLOG(3) << "Definition requires buffer in default memory.";
|
||||
}
|
||||
if (required_assignment.time == end_time) {
|
||||
use_requires_buffer_in_default_mem = true;
|
||||
VLOG(3) << "Use requires buffer in default memory.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// First try keeping the allocation entirely in the alternate memory.
|
||||
if (TryAllocatingInAlternateMemoryNoCopy(
|
||||
if (!definition_requires_buffer_in_default_mem &&
|
||||
!use_requires_buffer_in_default_mem &&
|
||||
TryAllocatingInAlternateMemoryNoCopy(
|
||||
start_time, end_time, last_use_time, defining_position, use,
|
||||
alternate_mem_interval, non_bitcast_operand, allocations)) {
|
||||
return true;
|
||||
@ -300,6 +374,15 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
|
||||
kDummyChunk, start_time, end_time));
|
||||
}
|
||||
|
||||
// If the use requires the buffer to be in default memory, don't try to
|
||||
// prefetch.
|
||||
if (use_requires_buffer_in_default_mem) {
|
||||
VLOG(4)
|
||||
<< "Not trying to prefetch because use requires buffer in default mem.";
|
||||
allocations->back()->AddUse(use);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try partially placing the buffer in the alternate space. The time that is
|
||||
// overlapped will be used to asynchronously copy the buffer from the
|
||||
// default memory to the alternate memory.
|
||||
|
@ -203,7 +203,7 @@ class MemorySpaceAssignment {
|
||||
|
||||
using AllocationSequence = std::list<std::unique_ptr<Allocation>>;
|
||||
using AllocationMap =
|
||||
absl::flat_hash_map<const HloBuffer*, AllocationSequence>;
|
||||
absl::flat_hash_map<const HloValue*, AllocationSequence>;
|
||||
|
||||
// Runs the MemorySpaceAssignment pass. alternate_memory_space is the
|
||||
// architecture-specific integer value that describes the alternate memory.
|
||||
@ -272,6 +272,15 @@ class MemorySpaceAssignment {
|
||||
std::vector<HloPosition> pending_positions_in_alternate_mem_;
|
||||
};
|
||||
|
||||
// This struct contains mandatory memory assignments at a given time. E.g., an
|
||||
// input's required memory assignment time would correspond to the definition
|
||||
// time of the parameter instruction, and an output's time would correspnd to
|
||||
// the time of last use.
|
||||
struct RequiredMemoryAssignment {
|
||||
MemorySpaceAssignment::MemorySpace memory_space;
|
||||
int64 time;
|
||||
};
|
||||
|
||||
// This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
|
||||
// maximum size.
|
||||
class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
@ -320,6 +329,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
HloInstruction* non_bitcast_operand,
|
||||
MemorySpaceAssignment::AllocationSequence* allocations);
|
||||
|
||||
// Adds input and outputs as required assignments.
|
||||
void AddInputAndOutputRequiredAssignments();
|
||||
|
||||
// Given a buffer interval, returns the colocated intervals. Unlike the
|
||||
// similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
|
||||
// returns the colocated intervals sorted by scheduled time.
|
||||
@ -373,6 +385,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
int64 max_outstanding_async_copies_;
|
||||
std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
|
||||
std::vector<std::pair<int64, int64>> pending_async_copies_;
|
||||
// This map contains required memory assignments for HloValues (e.g., input
|
||||
// and outputs).
|
||||
absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
|
||||
required_assignments_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -1168,5 +1168,60 @@ TEST_F(MemorySpaceAssignmentTest, TupleInput) {
|
||||
AssignMemorySpace(module.get());
|
||||
}
|
||||
|
||||
TEST_F(MemorySpaceAssignmentTest, InputOutputAlias) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
||||
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
|
||||
HloInstruction* p = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape, "p"));
|
||||
HloInstruction* p0 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(shape, p, 0));
|
||||
HloInstruction* negate0 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
|
||||
HloInstruction* negate1 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
|
||||
HloInstruction* negate2 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
|
||||
HloInstruction* negate3 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
|
||||
HloInstruction* negate4 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
|
||||
HloInstruction* negate5 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
|
||||
HloInstruction* negate6 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
|
||||
HloInstruction* p1 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(shape, p, 1));
|
||||
HloInstruction* add = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
|
||||
HloInstruction* negate7 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add));
|
||||
HloInstruction* tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({p0, add}));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
HloComputation* computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
HloSchedule schedule(module.get());
|
||||
schedule.set_sequence(
|
||||
computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
|
||||
negate6, p1, add, negate7, tuple});
|
||||
TF_CHECK_OK(module->set_schedule(schedule));
|
||||
|
||||
// Make input {0} alias with output {0} and input {1} alias with output {1}.
|
||||
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias(
|
||||
{0}, 0, {0}, HloInputOutputAliasConfig::AliasKind::kSystemAlias));
|
||||
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias(
|
||||
{1}, 0, {1}, HloInputOutputAliasConfig::AliasKind::kSystemAlias));
|
||||
|
||||
AssignMemorySpace(module.get());
|
||||
|
||||
// Make sure the input is in the default memory space.
|
||||
EXPECT_EQ(p->shape().tuple_shapes(0).layout().memory_space(),
|
||||
kDefaultMemorySpace);
|
||||
EXPECT_EQ(p->shape().tuple_shapes(1).layout().memory_space(),
|
||||
kDefaultMemorySpace);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user