[XLA] More tweaks and fixes to memory space assignment:
- When sorting the BufferIntervals, now we also correctly include aliased HloValues into the benefit computation. - To avoid double counting, we now skip over while and conditional HLOs when using cost analysis since we already look inside the called computations. - When trying to find a preferred alternate memory offset, we now find the latest use that can be contiguously allocated (determined by the max overlap ratio heuristic) instead of trying to find an allocation that is as long-living as possible. This should improve fragmentation slightly. PiperOrigin-RevId: 316777648 Change-Id: If9d40a79283b644db975ebe62b1bb6c545fea89d
This commit is contained in:
parent
a6945b9b0f
commit
fad7b3a33b
|
@ -31,6 +31,22 @@ const int kWhileExecutionCount = 5;
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
|
||||||
|
MemorySpaceAssignmentCostAnalysis::Create(
|
||||||
|
const HloCostAnalysis& cost_analysis,
|
||||||
|
float async_copy_bandwidth_bytes_per_second,
|
||||||
|
float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
|
||||||
|
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
|
||||||
|
HloLiveRange::Run(module.schedule(), *alias_analysis,
|
||||||
|
module.entry_computation()));
|
||||||
|
auto call_graph = CallGraph::Build(&module);
|
||||||
|
return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
|
||||||
|
cost_analysis, async_copy_bandwidth_bytes_per_second,
|
||||||
|
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
|
||||||
|
std::move(hlo_live_range), std::move(call_graph)));
|
||||||
|
}
|
||||||
|
|
||||||
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
|
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
|
||||||
const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
|
const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
|
||||||
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
|
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
|
||||||
|
@ -74,19 +90,32 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
|
||||||
/*operand_in_alternate_mem=*/{},
|
/*operand_in_alternate_mem=*/{},
|
||||||
/*output_in_alternate_mem=*/true),
|
/*output_in_alternate_mem=*/true),
|
||||||
cache);
|
cache);
|
||||||
for (const HloUse& use : interval.buffer->uses()) {
|
for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
|
||||||
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
|
interval.buffer->defining_position().instruction,
|
||||||
*use.instruction,
|
interval.buffer->defining_position().index)) {
|
||||||
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number),
|
for (const HloValue* value : buffer->values()) {
|
||||||
cache);
|
for (const HloUse& use : value->uses()) {
|
||||||
// If the benefit is positive (memory bound), add it to this buffer's
|
// We look inside the called computations of while and conditional, so
|
||||||
// benefit. If the benefit is negative (compute bound), calculate the
|
// don't use the benefit of while and conditional directly.
|
||||||
// maximum.
|
if (use.instruction->opcode() == HloOpcode::kWhile ||
|
||||||
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
|
use.instruction->opcode() == HloOpcode::kConditional) {
|
||||||
alternate_mem_benefit += use_alternate_mem_benefit;
|
continue;
|
||||||
} else {
|
}
|
||||||
alternate_mem_benefit =
|
float use_alternate_mem_benefit =
|
||||||
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
|
GetAlternateMemoryBenefit(*use.instruction,
|
||||||
|
GetInstructionElapsedDueToMemory(
|
||||||
|
*use.instruction, use.operand_number),
|
||||||
|
cache);
|
||||||
|
// If the benefit is positive (memory bound), add it to this buffer's
|
||||||
|
// benefit. If the benefit is negative (compute bound), calculate the
|
||||||
|
// maximum.
|
||||||
|
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
|
||||||
|
alternate_mem_benefit += use_alternate_mem_benefit;
|
||||||
|
} else {
|
||||||
|
alternate_mem_benefit =
|
||||||
|
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,17 +124,9 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
|
||||||
float alternate_mem_slowdown =
|
float alternate_mem_slowdown =
|
||||||
GetInstructionElapsedDueToMemorySlowdown(interval.size);
|
GetInstructionElapsedDueToMemorySlowdown(interval.size);
|
||||||
|
|
||||||
// Scale the slowdown based on the time of this buffer. We would want earlier
|
// Divide by the size of the buffer to prioritize smaller buffers that will
|
||||||
// buffers have lower slowdown values, because they are less likely to overlap
|
// give the largest alternate memory benefit.
|
||||||
// with other HLOs.
|
return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size;
|
||||||
// TODO(yuemmawang): We may want a piecewise function, where a lower slowdown
|
|
||||||
// for early HLOs, and full slowdown for mid-to-late HLOs.
|
|
||||||
// TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with
|
|
||||||
// more HLOs have higher slowdown, and vice versa.
|
|
||||||
float scale = interval.start * 1.0 / GetScheduleEndTime();
|
|
||||||
alternate_mem_slowdown *= scale;
|
|
||||||
|
|
||||||
return alternate_mem_benefit - alternate_mem_slowdown;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
|
int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
|
||||||
|
@ -113,7 +134,7 @@ int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
|
||||||
int nest_level = 0;
|
int nest_level = 0;
|
||||||
const HloComputation* computation = instruction->parent();
|
const HloComputation* computation = instruction->parent();
|
||||||
while (!computation->IsEntryComputation()) {
|
while (!computation->IsEntryComputation()) {
|
||||||
auto node = call_graph_.GetNode(computation);
|
auto node = call_graph_->GetNode(computation);
|
||||||
auto callsites = node.caller_callsites();
|
auto callsites = node.caller_callsites();
|
||||||
CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
|
CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
|
||||||
auto callsite = callsites[0];
|
auto callsite = callsites[0];
|
||||||
|
@ -195,7 +216,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
|
int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
|
||||||
return hlo_live_range_.schedule_end_time();
|
return hlo_live_range_->schedule_end_time();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
|
bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
|
||||||
|
@ -253,6 +274,13 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
|
||||||
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
|
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
|
||||||
0.0);
|
0.0);
|
||||||
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
|
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
|
||||||
|
// To avoid double counting, don't include the elapsed time of while and
|
||||||
|
// conditional HLOs.
|
||||||
|
const HloInstruction* instruction = instruction_and_logical_time.first;
|
||||||
|
if (instruction->opcode() == HloOpcode::kWhile ||
|
||||||
|
instruction->opcode() == HloOpcode::kConditional) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
|
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
|
||||||
*instruction_and_logical_time.first);
|
*instruction_and_logical_time.first);
|
||||||
int64 logical_time = instruction_and_logical_time.second;
|
int64 logical_time = instruction_and_logical_time.second;
|
||||||
|
@ -1937,17 +1965,38 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
|
||||||
BufferInterval* alternate_mem_interval) const {
|
BufferInterval* alternate_mem_interval) const {
|
||||||
int64 end_time = request.end_time;
|
int64 end_time = request.end_time;
|
||||||
if (!preferred_offset) {
|
if (!preferred_offset) {
|
||||||
|
// First find the earliest use that is the same or later than the end time.
|
||||||
|
const auto& uses = request.allocation_value->uses();
|
||||||
|
auto use_it = uses.begin();
|
||||||
|
for (; use_it->time < end_time; ++use_it) {
|
||||||
|
}
|
||||||
|
CHECK(use_it != uses.end());
|
||||||
|
int64 earliest_use = use_it->time;
|
||||||
|
|
||||||
|
// Then find the latest use that can be allocated contiguously without
|
||||||
|
// copies.
|
||||||
|
const Shape& shape = request.allocation_value->defining_position().shape();
|
||||||
|
for (;
|
||||||
|
(use_it + 1) != uses.end() &&
|
||||||
|
options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
|
||||||
|
shape, use_it->time, (use_it + 1)->time);
|
||||||
|
++use_it) {
|
||||||
|
}
|
||||||
|
CHECK(use_it != uses.end());
|
||||||
|
int64 latest_contiguous_use = use_it->time;
|
||||||
|
|
||||||
// Find a chunk that's as long living as possible iterating in reverse over
|
// Find a chunk that's as long living as possible iterating in reverse over
|
||||||
// the use times.
|
// the use times.
|
||||||
for (auto use_it = request.allocation_value->uses().rbegin();
|
for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) {
|
||||||
use_it != request.allocation_value->uses().rend() &&
|
|
||||||
use_it->time >= end_time;
|
|
||||||
++use_it) {
|
|
||||||
alternate_mem_interval->end = use_it->time;
|
alternate_mem_interval->end = use_it->time;
|
||||||
ChunkCandidate chunk_candidate =
|
ChunkCandidate chunk_candidate =
|
||||||
FindChunkCandidate(*alternate_mem_interval);
|
FindChunkCandidate(*alternate_mem_interval);
|
||||||
if (chunk_candidate.heap_size <= available_heap_size()) {
|
if (chunk_candidate.heap_size <= available_heap_size()) {
|
||||||
alternate_mem_interval->end = end_time;
|
alternate_mem_interval->end = end_time;
|
||||||
|
VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
|
||||||
|
<< ", latest contiguous use = " << latest_contiguous_use
|
||||||
|
<< ", use with available mem = " << use_it->time
|
||||||
|
<< ", offset = " << chunk_candidate.chunk.offset;
|
||||||
return chunk_candidate;
|
return chunk_candidate;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2005,8 +2054,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
|
||||||
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
||||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||||
MemorySpaceAssignmentCostAnalysis::Cache* cache) {
|
MemorySpaceAssignmentCostAnalysis::Cache* cache) {
|
||||||
return [cost_analysis, cache](const BufferInterval& x,
|
return [&cost_analysis, cache](const BufferInterval& x,
|
||||||
const BufferInterval& y) {
|
const BufferInterval& y) {
|
||||||
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
|
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
|
||||||
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
|
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
|
||||||
if (x_memory_boundedness != y_memory_boundedness) {
|
if (x_memory_boundedness != y_memory_boundedness) {
|
||||||
|
|
|
@ -84,18 +84,10 @@ class MemorySpaceAssignmentCostAnalysis {
|
||||||
absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
|
absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
|
||||||
};
|
};
|
||||||
|
|
||||||
MemorySpaceAssignmentCostAnalysis(
|
static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
|
||||||
const HloCostAnalysis& cost_analysis,
|
const HloCostAnalysis& cost_analysis,
|
||||||
float async_copy_bandwidth_bytes_per_second,
|
float async_copy_bandwidth_bytes_per_second,
|
||||||
float alternate_mem_bandwidth_bytes_per_second,
|
float alternate_mem_bandwidth_bytes_per_second, const HloModule& module);
|
||||||
const HloLiveRange& hlo_live_range, const CallGraph& call_graph)
|
|
||||||
: cost_analysis_(cost_analysis),
|
|
||||||
async_copy_bandwidth_bytes_per_second_(
|
|
||||||
async_copy_bandwidth_bytes_per_second),
|
|
||||||
alternate_mem_bandwidth_bytes_per_second_(
|
|
||||||
alternate_mem_bandwidth_bytes_per_second),
|
|
||||||
hlo_live_range_(hlo_live_range),
|
|
||||||
call_graph_(call_graph) {}
|
|
||||||
|
|
||||||
const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
|
const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
|
||||||
|
|
||||||
|
@ -153,14 +145,31 @@ class MemorySpaceAssignmentCostAnalysis {
|
||||||
// 0 means it is not in a while loop.
|
// 0 means it is not in a while loop.
|
||||||
int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;
|
int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;
|
||||||
|
|
||||||
const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
|
const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
MemorySpaceAssignmentCostAnalysis(
|
||||||
|
const HloCostAnalysis& cost_analysis,
|
||||||
|
float async_copy_bandwidth_bytes_per_second,
|
||||||
|
float alternate_mem_bandwidth_bytes_per_second,
|
||||||
|
std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||||
|
std::unique_ptr<HloLiveRange> hlo_live_range,
|
||||||
|
std::unique_ptr<CallGraph> call_graph)
|
||||||
|
: cost_analysis_(cost_analysis),
|
||||||
|
async_copy_bandwidth_bytes_per_second_(
|
||||||
|
async_copy_bandwidth_bytes_per_second),
|
||||||
|
alternate_mem_bandwidth_bytes_per_second_(
|
||||||
|
alternate_mem_bandwidth_bytes_per_second),
|
||||||
|
alias_analysis_(std::move(alias_analysis)),
|
||||||
|
hlo_live_range_(std::move(hlo_live_range)),
|
||||||
|
call_graph_(std::move(call_graph)) {}
|
||||||
|
|
||||||
const HloCostAnalysis& cost_analysis_;
|
const HloCostAnalysis& cost_analysis_;
|
||||||
float async_copy_bandwidth_bytes_per_second_;
|
float async_copy_bandwidth_bytes_per_second_;
|
||||||
float alternate_mem_bandwidth_bytes_per_second_;
|
float alternate_mem_bandwidth_bytes_per_second_;
|
||||||
const HloLiveRange& hlo_live_range_;
|
std::unique_ptr<HloAliasAnalysis> alias_analysis_;
|
||||||
const CallGraph& call_graph_;
|
std::unique_ptr<HloLiveRange> hlo_live_range_;
|
||||||
|
std::unique_ptr<CallGraph> call_graph_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Abstract base class that memory space assignment uses to pick prefetch
|
// Abstract base class that memory space assignment uses to pick prefetch
|
||||||
|
|
|
@ -53,22 +53,18 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
||||||
TF_CHECK_OK(computation->Accept(&hlo_cost_analysis));
|
TF_CHECK_OK(computation->Accept(&hlo_cost_analysis));
|
||||||
}
|
}
|
||||||
auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
|
auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
|
||||||
std::unique_ptr<HloLiveRange> hlo_live_range =
|
auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create(
|
||||||
HloLiveRange::Run(module->schedule(), *alias_analysis,
|
hlo_cost_analysis, kAsyncCopyBandwidth,
|
||||||
module->entry_computation())
|
kAlternateMemBandwidth, *module)
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
|
||||||
MemorySpaceAssignmentCostAnalysis cost_analysis(
|
|
||||||
hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth,
|
|
||||||
*hlo_live_range, *call_graph);
|
|
||||||
CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
|
CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
|
||||||
CostAnalysisPrefetchIntervalPicker(
|
CostAnalysisPrefetchIntervalPicker(
|
||||||
cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
|
*cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
|
||||||
/*max_async_copy_to_overlap_ratio=*/10.0));
|
/*max_async_copy_to_overlap_ratio=*/10.0));
|
||||||
return AssignMemorySpace(
|
return AssignMemorySpace(
|
||||||
module, /*max_outstanding_async_copies=*/-1,
|
module, /*max_outstanding_async_copies=*/-1,
|
||||||
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
||||||
cost_analysis, &cache_),
|
*cost_analysis, &cache_),
|
||||||
&prefetch_interval_picker);
|
&prefetch_interval_picker);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue