[XLA] More tweaks and fixes to memory space assignment:

- When sorting the BufferIntervals, now we also correctly include aliased
  HloValues into the benefit computation.
- To avoid double counting, we now skip over while and conditional HLOs when
  using cost analysis since we already look inside the called computations.
- When trying to find a preferred alternate memory offset, we now find the
  latest use that can be contiguously allocated (determined by the max overlap
  ratio heuristic) instead of trying to find an allocation that is as
  long-living as possible. This should improve fragmentation slightly.

PiperOrigin-RevId: 316777648
Change-Id: If9d40a79283b644db975ebe62b1bb6c545fea89d
This commit is contained in:
Berkin Ilbeyi 2020-06-16 16:12:25 -07:00 committed by TensorFlower Gardener
parent a6945b9b0f
commit fad7b3a33b
3 changed files with 109 additions and 55 deletions

View File

@ -31,6 +31,22 @@ const int kWhileExecutionCount = 5;
} // namespace
/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
MemorySpaceAssignmentCostAnalysis::Create(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) {
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
HloLiveRange::Run(module.schedule(), *alias_analysis,
module.entry_computation()));
auto call_graph = CallGraph::Build(&module);
return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
cost_analysis, async_copy_bandwidth_bytes_per_second,
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
std::move(hlo_live_range), std::move(call_graph)));
}
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
@ -74,19 +90,32 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
/*operand_in_alternate_mem=*/{},
/*output_in_alternate_mem=*/true),
cache);
for (const HloUse& use : interval.buffer->uses()) {
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
*use.instruction,
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number),
cache);
// If the benefit is positive (memory bound), add it to this buffer's
// benefit. If the benefit is negative (compute bound), calculate the
// maximum.
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
alternate_mem_benefit += use_alternate_mem_benefit;
} else {
alternate_mem_benefit =
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
interval.buffer->defining_position().instruction,
interval.buffer->defining_position().index)) {
for (const HloValue* value : buffer->values()) {
for (const HloUse& use : value->uses()) {
// We look inside the called computations of while and conditional, so
// don't use the benefit of while and conditional directly.
if (use.instruction->opcode() == HloOpcode::kWhile ||
use.instruction->opcode() == HloOpcode::kConditional) {
continue;
}
float use_alternate_mem_benefit =
GetAlternateMemoryBenefit(*use.instruction,
GetInstructionElapsedDueToMemory(
*use.instruction, use.operand_number),
cache);
// If the benefit is positive (memory bound), add it to this buffer's
// benefit. If the benefit is negative (compute bound), calculate the
// maximum.
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
alternate_mem_benefit += use_alternate_mem_benefit;
} else {
alternate_mem_benefit =
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
}
}
}
}
@ -95,17 +124,9 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
float alternate_mem_slowdown =
GetInstructionElapsedDueToMemorySlowdown(interval.size);
// Scale the slowdown based on the time of this buffer. We would want earlier
// buffers have lower slowdown values, because they are less likely to overlap
// with other HLOs.
// TODO(yuemmawang): We may want a piecewise function, where a lower slowdown
// for early HLOs, and full slowdown for mid-to-late HLOs.
// TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with
// more HLOs have higher slowdown, and vice versa.
float scale = interval.start * 1.0 / GetScheduleEndTime();
alternate_mem_slowdown *= scale;
return alternate_mem_benefit - alternate_mem_slowdown;
// Divide by the size of the buffer to prioritize smaller buffers that will
// give the largest alternate memory benefit.
return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size;
}
int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
@ -113,7 +134,7 @@ int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
int nest_level = 0;
const HloComputation* computation = instruction->parent();
while (!computation->IsEntryComputation()) {
auto node = call_graph_.GetNode(computation);
auto node = call_graph_->GetNode(computation);
auto callsites = node.caller_callsites();
CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
auto callsite = callsites[0];
@ -195,7 +216,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
}
int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
return hlo_live_range_.schedule_end_time();
return hlo_live_range_->schedule_end_time();
}
bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
@ -253,6 +274,13 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
0.0);
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
// To avoid double counting, don't include the elapsed time of while and
// conditional HLOs.
const HloInstruction* instruction = instruction_and_logical_time.first;
if (instruction->opcode() == HloOpcode::kWhile ||
instruction->opcode() == HloOpcode::kConditional) {
continue;
}
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
*instruction_and_logical_time.first);
int64 logical_time = instruction_and_logical_time.second;
@ -1937,17 +1965,38 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
BufferInterval* alternate_mem_interval) const {
int64 end_time = request.end_time;
if (!preferred_offset) {
// First find the earliest use that is the same or later than the end time.
const auto& uses = request.allocation_value->uses();
auto use_it = uses.begin();
for (; use_it->time < end_time; ++use_it) {
}
CHECK(use_it != uses.end());
int64 earliest_use = use_it->time;
// Then find the latest use that can be allocated contiguously without
// copies.
const Shape& shape = request.allocation_value->defining_position().shape();
for (;
(use_it + 1) != uses.end() &&
options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
shape, use_it->time, (use_it + 1)->time);
++use_it) {
}
CHECK(use_it != uses.end());
int64 latest_contiguous_use = use_it->time;
// Find a chunk that's as long living as possible iterating in reverse over
// the use times.
for (auto use_it = request.allocation_value->uses().rbegin();
use_it != request.allocation_value->uses().rend() &&
use_it->time >= end_time;
++use_it) {
for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) {
alternate_mem_interval->end = use_it->time;
ChunkCandidate chunk_candidate =
FindChunkCandidate(*alternate_mem_interval);
if (chunk_candidate.heap_size <= available_heap_size()) {
alternate_mem_interval->end = end_time;
VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
<< ", latest contiguous use = " << latest_contiguous_use
<< ", use with available mem = " << use_it->time
<< ", offset = " << chunk_candidate.chunk.offset;
return chunk_candidate;
}
}
@ -2005,8 +2054,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
MemorySpaceAssignmentCostAnalysis::Cache* cache) {
return [cost_analysis, cache](const BufferInterval& x,
const BufferInterval& y) {
return [&cost_analysis, cache](const BufferInterval& x,
const BufferInterval& y) {
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
if (x_memory_boundedness != y_memory_boundedness) {

View File

@ -84,18 +84,10 @@ class MemorySpaceAssignmentCostAnalysis {
absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
};
MemorySpaceAssignmentCostAnalysis(
static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second,
const HloLiveRange& hlo_live_range, const CallGraph& call_graph)
: cost_analysis_(cost_analysis),
async_copy_bandwidth_bytes_per_second_(
async_copy_bandwidth_bytes_per_second),
alternate_mem_bandwidth_bytes_per_second_(
alternate_mem_bandwidth_bytes_per_second),
hlo_live_range_(hlo_live_range),
call_graph_(call_graph) {}
float alternate_mem_bandwidth_bytes_per_second, const HloModule& module);
const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
@ -153,14 +145,31 @@ class MemorySpaceAssignmentCostAnalysis {
// 0 means it is not in a while loop.
int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;
const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
private:
MemorySpaceAssignmentCostAnalysis(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second,
std::unique_ptr<HloAliasAnalysis> alias_analysis,
std::unique_ptr<HloLiveRange> hlo_live_range,
std::unique_ptr<CallGraph> call_graph)
: cost_analysis_(cost_analysis),
async_copy_bandwidth_bytes_per_second_(
async_copy_bandwidth_bytes_per_second),
alternate_mem_bandwidth_bytes_per_second_(
alternate_mem_bandwidth_bytes_per_second),
alias_analysis_(std::move(alias_analysis)),
hlo_live_range_(std::move(hlo_live_range)),
call_graph_(std::move(call_graph)) {}
const HloCostAnalysis& cost_analysis_;
float async_copy_bandwidth_bytes_per_second_;
float alternate_mem_bandwidth_bytes_per_second_;
const HloLiveRange& hlo_live_range_;
const CallGraph& call_graph_;
std::unique_ptr<HloAliasAnalysis> alias_analysis_;
std::unique_ptr<HloLiveRange> hlo_live_range_;
std::unique_ptr<CallGraph> call_graph_;
};
// Abstract base class that memory space assignment uses to pick prefetch

View File

@ -53,22 +53,18 @@ class MemorySpaceAssignmentTest : public HloTestBase,
TF_CHECK_OK(computation->Accept(&hlo_cost_analysis));
}
auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
std::unique_ptr<HloLiveRange> hlo_live_range =
HloLiveRange::Run(module->schedule(), *alias_analysis,
module->entry_computation())
.ValueOrDie();
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
MemorySpaceAssignmentCostAnalysis cost_analysis(
hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth,
*hlo_live_range, *call_graph);
auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create(
hlo_cost_analysis, kAsyncCopyBandwidth,
kAlternateMemBandwidth, *module)
.ValueOrDie();
CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
CostAnalysisPrefetchIntervalPicker(
cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
*cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
/*max_async_copy_to_overlap_ratio=*/10.0));
return AssignMemorySpace(
module, /*max_outstanding_async_copies=*/-1,
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
cost_analysis, &cache_),
*cost_analysis, &cache_),
&prefetch_interval_picker);
}