[XLA] Improve cost analysis for while loops.
To prioritize alternate memory allocations for HLOs inside (possibly nested) while loops, the cost model now assumes that each while loop executes a heuristic constant number of times (currently 5). Instructions inside nested while loops are therefore assumed to execute pow(5, nesting_level) times. PiperOrigin-RevId: 312288904 Change-Id: Ibb177ac971922e0660cd0385f1b38d223804d0c9
This commit is contained in:
parent
9bdd084064
commit
be4980e340
@ -3284,6 +3284,7 @@ cc_library(
|
||||
":heap_simulator",
|
||||
":hlo_cost_analysis",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/core/lib/math:math_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,53 +16,52 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
// Define a dummy chunk for chunks that will be allocated in the default memory
|
||||
// space and for keeping track of number of asynchronous copies.
|
||||
const HeapSimulator::Chunk kDummyChunk{-1, -1};
|
||||
// This variable is used by the cost analysis in estimating how many times each
|
||||
// while loop will execute. Nested loops will be assumed to have executed
|
||||
// pow(kWhileExecutionCount, nesting_level) times.
|
||||
const int kWhileExecutionCount = 5;
|
||||
|
||||
// Returns a heuristic value that captures how much putting this tensor to
|
||||
// the alternate memory would help if the op is memory bound, or otherwise
|
||||
// how far off is the op to memory boundedness. The larger this number, the
|
||||
// higher priority it will be placed in the alternate memory.
|
||||
float GetAlternateMemoryBenefit(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||
} // namespace
|
||||
|
||||
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
|
||||
const HloInstruction& instruction,
|
||||
float elapsed_time_due_to_alternate_mem) {
|
||||
float elapsed_time_due_to_alternate_mem) const {
|
||||
float elapsed_time_due_to_compute =
|
||||
cost_analysis.GetInstructionElapsedDueToCompute(instruction);
|
||||
GetInstructionElapsedDueToCompute(instruction);
|
||||
float elapsed_time_due_to_memory =
|
||||
cost_analysis.GetInstructionElapsedDueToMemory(instruction);
|
||||
GetInstructionElapsedDueToMemory(instruction);
|
||||
if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
|
||||
// Memory bound, return how much alternate memory is better.
|
||||
return elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem;
|
||||
int while_nest_level = CalculateWhileLoopNestLevel(&instruction);
|
||||
return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
|
||||
tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
|
||||
while_nest_level);
|
||||
} else {
|
||||
// Compute bound, return how far off are we to memory boundedness.
|
||||
return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a heuristic value of memory boundedness for the given BufferInterval.
|
||||
// The larger this number, the higher priority it will be placed in the
|
||||
// alternate memory.
|
||||
float GetMemoryBoundedness(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
|
||||
float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
|
||||
const HloInstruction& defining_instruction =
|
||||
*interval.buffer->defining_instruction();
|
||||
float alternate_mem_benefit =
|
||||
GetAlternateMemoryBenefit(cost_analysis, defining_instruction,
|
||||
cost_analysis.GetInstructionElapsedDueToMemory(
|
||||
float alternate_mem_benefit = GetAlternateMemoryBenefit(
|
||||
defining_instruction,
|
||||
GetInstructionElapsedDueToMemory(defining_instruction,
|
||||
/*operand_in_alternate_mem=*/{},
|
||||
/*output_in_alternate_mem=*/true));
|
||||
for (const HloUse& use : interval.buffer->uses()) {
|
||||
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
|
||||
cost_analysis, *use.instruction,
|
||||
cost_analysis.GetInstructionElapsedDueToMemory(*use.instruction,
|
||||
use.operand_number));
|
||||
*use.instruction,
|
||||
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number));
|
||||
// If the benefit is positive (memory bound), add it to this buffer's
|
||||
// benefit. If the benefit is negative (compute bound), calculate the
|
||||
// maximum.
|
||||
@ -77,7 +76,7 @@ float GetMemoryBoundedness(
|
||||
// Get performance slowdown in seconds of prefetching current BufferInterval
|
||||
// causing to other BufferIntervals.
|
||||
float alternate_mem_slowdown =
|
||||
cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size);
|
||||
GetInstructionElapsedDueToMemorySlowdown(interval.size);
|
||||
|
||||
// Scale the slowdown based on the time of this buffer. We would want earlier
|
||||
// buffers have lower slowdown values, because they are less likely to overlap
|
||||
@ -86,13 +85,28 @@ float GetMemoryBoundedness(
|
||||
// for early HLOs, and full slowdown for mid-to-late HLOs.
|
||||
// TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with
|
||||
// more HLOs have higher slowdown, and vice versa.
|
||||
float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime();
|
||||
float scale = interval.start * 1.0 / GetScheduleEndTime();
|
||||
alternate_mem_slowdown *= scale;
|
||||
|
||||
return alternate_mem_benefit - alternate_mem_slowdown;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
|
||||
const HloInstruction* instruction) const {
|
||||
int nest_level = 0;
|
||||
const HloComputation* computation = instruction->parent();
|
||||
while (!computation->IsEntryComputation()) {
|
||||
auto node = call_graph_.GetNode(computation);
|
||||
auto callsites = node.caller_callsites();
|
||||
CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
|
||||
auto callsite = callsites[0];
|
||||
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
|
||||
++nest_level;
|
||||
}
|
||||
computation = callsite.instruction()->parent();
|
||||
}
|
||||
return nest_level;
|
||||
}
|
||||
|
||||
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
|
||||
const HloInstruction& instruction) const {
|
||||
@ -207,29 +221,30 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||
float min_async_copy_to_overlap_ratio,
|
||||
float max_async_copy_to_overlap_ratio)
|
||||
: cost_analysis_(cost_analysis),
|
||||
: elapsed_time_(
|
||||
cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0),
|
||||
while_nest_level_(
|
||||
cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
|
||||
cost_analysis_(cost_analysis),
|
||||
min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
|
||||
max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) {
|
||||
instruction_schedule_ =
|
||||
&cost_analysis_.hlo_live_range().instruction_schedule();
|
||||
|
||||
// First create a vector of elapsed times of HLO instructions.
|
||||
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
|
||||
0.0);
|
||||
// Create a vector of elapsed times and while nesting levels of HLO
|
||||
// instructions.
|
||||
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
|
||||
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
|
||||
*instruction_and_logical_time.first);
|
||||
int64 logical_time = instruction_and_logical_time.second;
|
||||
if (logical_time >= instructions_elapsed_time.size()) {
|
||||
instructions_elapsed_time.resize(logical_time + 1, 0.0);
|
||||
if (logical_time >= elapsed_time_.size()) {
|
||||
elapsed_time_.resize(logical_time + 1, 0.0);
|
||||
while_nest_level_.resize(logical_time + 1, 0);
|
||||
}
|
||||
instructions_elapsed_time[logical_time] = elapsed_time;
|
||||
}
|
||||
// As an optimization, create a cumulative sum vector of elapsed time.
|
||||
float cumsum = 0.0;
|
||||
for (float elapsed_time : instructions_elapsed_time) {
|
||||
cumsum += elapsed_time;
|
||||
elapsed_time_cumsum_.push_back(cumsum);
|
||||
elapsed_time_[logical_time] = elapsed_time;
|
||||
while_nest_level_[logical_time] =
|
||||
cost_analysis_.CalculateWhileLoopNestLevel(
|
||||
instruction_and_logical_time.first);
|
||||
}
|
||||
}
|
||||
|
||||
@ -303,7 +318,17 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const {
|
||||
|
||||
float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
|
||||
int64 start_time, int64 end_time) const {
|
||||
return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time];
|
||||
int interval_nest_level =
|
||||
std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
|
||||
float total_elapsed = 0;
|
||||
for (int i = start_time + 1; i < end_time; ++i) {
|
||||
total_elapsed +=
|
||||
elapsed_time_[i] *
|
||||
tensorflow::MathUtil::IPow<float>(
|
||||
kWhileExecutionCount,
|
||||
std::max(0, while_nest_level_[i] - interval_nest_level));
|
||||
}
|
||||
return total_elapsed;
|
||||
}
|
||||
|
||||
std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
|
||||
@ -328,7 +353,7 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
|
||||
absl::optional<float>
|
||||
CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
|
||||
return GetMemoryBoundedness(cost_analysis_, interval);
|
||||
return cost_analysis_.GetMemoryBoundedness(interval);
|
||||
}
|
||||
|
||||
std::string MemorySpaceAssignment::AllocationValue::ToString() const {
|
||||
@ -805,8 +830,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
}
|
||||
|
||||
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
||||
global_max_time_ = instruction_schedule.at(
|
||||
module->entry_computation()->root_instruction());
|
||||
|
||||
// TODO(berkin): For now, place the phi values due to conditionals in
|
||||
// default memory.
|
||||
@ -1609,6 +1632,9 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
|
||||
request.allocation_value->defining_position().shape(),
|
||||
eviction_start_time, request.end_time),
|
||||
eviction_end_time);
|
||||
// Evictions must complete by the time of this use.
|
||||
preferred_eviction_end_time =
|
||||
std::min(preferred_eviction_end_time, request.latest_prefetch_time);
|
||||
|
||||
BufferInterval eviction_mem_interval;
|
||||
eviction_mem_interval.buffer = request.allocation_value->value();
|
||||
@ -1616,8 +1642,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
|
||||
// Try to reserve a buffer from the end of the previous allocation to the
|
||||
// preferred eviction end time.
|
||||
eviction_mem_interval.start = eviction_end_time + 1;
|
||||
eviction_mem_interval.end =
|
||||
std::min(preferred_eviction_end_time, global_max_time_);
|
||||
eviction_mem_interval.end = preferred_eviction_end_time;
|
||||
int64 preferred_offset = prev_allocation->chunk().offset;
|
||||
VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
|
||||
<< ") preferred end time = " << eviction_mem_interval.end;
|
||||
@ -1834,8 +1859,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
|
||||
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
|
||||
return [&](const BufferInterval& x, const BufferInterval& y) {
|
||||
float x_memory_boundedness = GetMemoryBoundedness(cost_analysis, x);
|
||||
float y_memory_boundedness = GetMemoryBoundedness(cost_analysis, y);
|
||||
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x);
|
||||
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y);
|
||||
if (x_memory_boundedness != y_memory_boundedness) {
|
||||
return x_memory_boundedness > y_memory_boundedness;
|
||||
}
|
||||
|
@ -82,16 +82,31 @@ class MemorySpaceAssignmentCostAnalysis {
|
||||
const HloCostAnalysis& cost_analysis,
|
||||
float async_copy_bandwidth_bytes_per_second,
|
||||
float alternate_mem_bandwidth_bytes_per_second,
|
||||
const HloLiveRange& hlo_live_range)
|
||||
const HloLiveRange& hlo_live_range, const CallGraph& call_graph)
|
||||
: cost_analysis_(cost_analysis),
|
||||
async_copy_bandwidth_bytes_per_second_(
|
||||
async_copy_bandwidth_bytes_per_second),
|
||||
alternate_mem_bandwidth_bytes_per_second_(
|
||||
alternate_mem_bandwidth_bytes_per_second),
|
||||
hlo_live_range_(hlo_live_range) {}
|
||||
hlo_live_range_(hlo_live_range),
|
||||
call_graph_(call_graph) {}
|
||||
|
||||
const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
|
||||
|
||||
// Returns a heuristic value that captures how much putting this tensor to the
|
||||
// alternate memory would help if the op is memory bound, or otherwise how far
|
||||
// off is the op to memory boundedness. The larger this number, the higher
|
||||
// priority it will be placed in the alternate memory.
|
||||
float GetAlternateMemoryBenefit(
|
||||
const HloInstruction& instruction,
|
||||
float elapsed_time_due_to_alternate_mem) const;
|
||||
|
||||
// Returns a heuristic value of memory boundedness for the given
|
||||
// BufferInterval. The larger this number, the higher priority it will be
|
||||
// placed in the alternate memory.
|
||||
float GetMemoryBoundedness(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const;
|
||||
|
||||
// Returns the elapsed time in seconds due to compute only.
|
||||
float GetInstructionElapsedDueToCompute(
|
||||
const HloInstruction& instruction) const;
|
||||
@ -127,6 +142,10 @@ class MemorySpaceAssignmentCostAnalysis {
|
||||
|
||||
int64 GetScheduleEndTime() const;
|
||||
|
||||
// Returns the number of nested while loop levels this instruction resides in.
|
||||
// 0 means it is not in a while loop.
|
||||
int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;
|
||||
|
||||
const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
|
||||
|
||||
private:
|
||||
@ -134,6 +153,7 @@ class MemorySpaceAssignmentCostAnalysis {
|
||||
float async_copy_bandwidth_bytes_per_second_;
|
||||
float alternate_mem_bandwidth_bytes_per_second_;
|
||||
const HloLiveRange& hlo_live_range_;
|
||||
const CallGraph& call_graph_;
|
||||
};
|
||||
|
||||
// Abstract base class that memory space assignment uses to pick prefetch
|
||||
@ -262,10 +282,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
|
||||
// corresponds to the instruction schedule.
|
||||
float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;
|
||||
|
||||
// For performance reasons, we calculate the prefix sum of the elapsed time so
|
||||
// that it's efficient to find the elapsed time in seconds in any logical
|
||||
// interval.
|
||||
std::vector<float> elapsed_time_cumsum_;
|
||||
// For each instruction in the flattened schedule, maintain their elapsed time
|
||||
// and while nesting level.
|
||||
std::vector<float> elapsed_time_;
|
||||
std::vector<int> while_nest_level_;
|
||||
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
|
||||
float min_async_copy_to_overlap_ratio_;
|
||||
@ -988,7 +1008,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
required_assignments_;
|
||||
// Number of bytes reserved in alternate memory space.
|
||||
int64 reserved_in_bytes_ = 0;
|
||||
int64 global_max_time_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -57,9 +57,10 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
||||
HloLiveRange::Run(module->schedule(), *alias_analysis,
|
||||
module->entry_computation())
|
||||
.ValueOrDie();
|
||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||
MemorySpaceAssignmentCostAnalysis cost_analysis(
|
||||
hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth,
|
||||
*hlo_live_range);
|
||||
*hlo_live_range, *call_graph);
|
||||
CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
|
||||
CostAnalysisPrefetchIntervalPicker(
|
||||
cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
|
||||
|
Loading…
Reference in New Issue
Block a user