[XLA] Improve cost analysis for while loops.

In order to prioritize alternate memory allocations for HLOs in (nested) while
loops, the cost model now accounts for these instructions as executing a
heuristic constant number of times (currently 5). Nested while loops will be
calculated to have executed pow(5, nesting_level) times.

PiperOrigin-RevId: 312288904
Change-Id: Ibb177ac971922e0660cd0385f1b38d223804d0c9
This commit is contained in:
Berkin Ilbeyi 2020-05-19 08:54:59 -07:00 committed by TensorFlower Gardener
parent 9bdd084064
commit be4980e340
4 changed files with 103 additions and 57 deletions

View File

@ -3284,6 +3284,7 @@ cc_library(
":heap_simulator",
":hlo_cost_analysis",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/core/lib/math:math_util",
],
)

View File

@ -16,53 +16,52 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace xla {
namespace {
// Define a dummy chunk for chunks that will be allocated in the default memory
// space and for keeping track of number of asynchronous copies.
const HeapSimulator::Chunk kDummyChunk{-1, -1};
// This variable is used by the cost analysis in estimating how many times each
// while loop will execute. Nested loops will be assumed to have executed
// pow(kWhileExecutionCount, nesting_level) times.
const int kWhileExecutionCount = 5;
// Returns a heuristic value that captures how much putting this tensor to
// the alternate memory would help if the op is memory bound, or otherwise
// how far off is the op to memory boundedness. The larger this number, the
// higher priority it will be placed in the alternate memory.
float GetAlternateMemoryBenefit(
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
} // namespace
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
const HloInstruction& instruction,
float elapsed_time_due_to_alternate_mem) {
float elapsed_time_due_to_alternate_mem) const {
float elapsed_time_due_to_compute =
cost_analysis.GetInstructionElapsedDueToCompute(instruction);
GetInstructionElapsedDueToCompute(instruction);
float elapsed_time_due_to_memory =
cost_analysis.GetInstructionElapsedDueToMemory(instruction);
GetInstructionElapsedDueToMemory(instruction);
if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
// Memory bound, return how much alternate memory is better.
return elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem;
int while_nest_level = CalculateWhileLoopNestLevel(&instruction);
return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
while_nest_level);
} else {
// Compute bound, return how far off are we to memory boundedness.
return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
}
}
// Returns a heuristic value of memory boundedness for the given BufferInterval.
// The larger this number, the higher priority it will be placed in the
// alternate memory.
float GetMemoryBoundedness(
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
const HloInstruction& defining_instruction =
*interval.buffer->defining_instruction();
float alternate_mem_benefit =
GetAlternateMemoryBenefit(cost_analysis, defining_instruction,
cost_analysis.GetInstructionElapsedDueToMemory(
float alternate_mem_benefit = GetAlternateMemoryBenefit(
defining_instruction,
GetInstructionElapsedDueToMemory(defining_instruction,
/*operand_in_alternate_mem=*/{},
/*output_in_alternate_mem=*/true));
for (const HloUse& use : interval.buffer->uses()) {
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
cost_analysis, *use.instruction,
cost_analysis.GetInstructionElapsedDueToMemory(*use.instruction,
use.operand_number));
*use.instruction,
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number));
// If the benefit is positive (memory bound), add it to this buffer's
// benefit. If the benefit is negative (compute bound), calculate the
// maximum.
@ -77,7 +76,7 @@ float GetMemoryBoundedness(
// Get performance slowdown in seconds of prefetching current BufferInterval
// causing to other BufferIntervals.
float alternate_mem_slowdown =
cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size);
GetInstructionElapsedDueToMemorySlowdown(interval.size);
// Scale the slowdown based on the time of this buffer. We would want earlier
// buffers have lower slowdown values, because they are less likely to overlap
@ -86,13 +85,28 @@ float GetMemoryBoundedness(
// for early HLOs, and full slowdown for mid-to-late HLOs.
// TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with
// more HLOs have higher slowdown, and vice versa.
float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime();
float scale = interval.start * 1.0 / GetScheduleEndTime();
alternate_mem_slowdown *= scale;
return alternate_mem_benefit - alternate_mem_slowdown;
}
} // namespace
int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
const HloInstruction* instruction) const {
int nest_level = 0;
const HloComputation* computation = instruction->parent();
while (!computation->IsEntryComputation()) {
auto node = call_graph_.GetNode(computation);
auto callsites = node.caller_callsites();
CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
auto callsite = callsites[0];
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
++nest_level;
}
computation = callsite.instruction()->parent();
}
return nest_level;
}
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
const HloInstruction& instruction) const {
@ -207,29 +221,30 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
float min_async_copy_to_overlap_ratio,
float max_async_copy_to_overlap_ratio)
: cost_analysis_(cost_analysis),
: elapsed_time_(
cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0),
while_nest_level_(
cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
cost_analysis_(cost_analysis),
min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) {
instruction_schedule_ =
&cost_analysis_.hlo_live_range().instruction_schedule();
// First create a vector of elapsed times of HLO instructions.
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
0.0);
// Create a vector of elapsed times and while nesting levels of HLO
// instructions.
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
*instruction_and_logical_time.first);
int64 logical_time = instruction_and_logical_time.second;
if (logical_time >= instructions_elapsed_time.size()) {
instructions_elapsed_time.resize(logical_time + 1, 0.0);
if (logical_time >= elapsed_time_.size()) {
elapsed_time_.resize(logical_time + 1, 0.0);
while_nest_level_.resize(logical_time + 1, 0);
}
instructions_elapsed_time[logical_time] = elapsed_time;
}
// As an optimization, create a cumulative sum vector of elapsed time.
float cumsum = 0.0;
for (float elapsed_time : instructions_elapsed_time) {
cumsum += elapsed_time;
elapsed_time_cumsum_.push_back(cumsum);
elapsed_time_[logical_time] = elapsed_time;
while_nest_level_[logical_time] =
cost_analysis_.CalculateWhileLoopNestLevel(
instruction_and_logical_time.first);
}
}
@ -303,7 +318,17 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const {
float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
int64 start_time, int64 end_time) const {
return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time];
int interval_nest_level =
std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
float total_elapsed = 0;
for (int i = start_time + 1; i < end_time; ++i) {
total_elapsed +=
elapsed_time_[i] *
tensorflow::MathUtil::IPow<float>(
kWhileExecutionCount,
std::max(0, while_nest_level_[i] - interval_nest_level));
}
return total_elapsed;
}
std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
@ -328,7 +353,7 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
absl::optional<float>
CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
return GetMemoryBoundedness(cost_analysis_, interval);
return cost_analysis_.GetMemoryBoundedness(interval);
}
std::string MemorySpaceAssignment::AllocationValue::ToString() const {
@ -805,8 +830,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
}
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
global_max_time_ = instruction_schedule.at(
module->entry_computation()->root_instruction());
// TODO(berkin): For now, place the phi values due to conditionals in
// default memory.
@ -1609,6 +1632,9 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
request.allocation_value->defining_position().shape(),
eviction_start_time, request.end_time),
eviction_end_time);
// Evictions must complete by the time of this use.
preferred_eviction_end_time =
std::min(preferred_eviction_end_time, request.latest_prefetch_time);
BufferInterval eviction_mem_interval;
eviction_mem_interval.buffer = request.allocation_value->value();
@ -1616,8 +1642,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
// Try to reserve a buffer from the end of the previous allocation to the
// preferred eviction end time.
eviction_mem_interval.start = eviction_end_time + 1;
eviction_mem_interval.end =
std::min(preferred_eviction_end_time, global_max_time_);
eviction_mem_interval.end = preferred_eviction_end_time;
int64 preferred_offset = prev_allocation->chunk().offset;
VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
<< ") preferred end time = " << eviction_mem_interval.end;
@ -1834,8 +1859,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
return [&](const BufferInterval& x, const BufferInterval& y) {
float x_memory_boundedness = GetMemoryBoundedness(cost_analysis, x);
float y_memory_boundedness = GetMemoryBoundedness(cost_analysis, y);
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x);
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y);
if (x_memory_boundedness != y_memory_boundedness) {
return x_memory_boundedness > y_memory_boundedness;
}

View File

@ -82,16 +82,31 @@ class MemorySpaceAssignmentCostAnalysis {
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second,
const HloLiveRange& hlo_live_range)
const HloLiveRange& hlo_live_range, const CallGraph& call_graph)
: cost_analysis_(cost_analysis),
async_copy_bandwidth_bytes_per_second_(
async_copy_bandwidth_bytes_per_second),
alternate_mem_bandwidth_bytes_per_second_(
alternate_mem_bandwidth_bytes_per_second),
hlo_live_range_(hlo_live_range) {}
hlo_live_range_(hlo_live_range),
call_graph_(call_graph) {}
const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
// Returns a heuristic value that captures how much putting this tensor to the
// alternate memory would help if the op is memory bound, or otherwise how far
// off is the op to memory boundedness. The larger this number, the higher
// priority it will be placed in the alternate memory.
float GetAlternateMemoryBenefit(
const HloInstruction& instruction,
float elapsed_time_due_to_alternate_mem) const;
// Returns a heuristic value of memory boundedness for the given
// BufferInterval. The larger this number, the higher priority it will be
// placed in the alternate memory.
float GetMemoryBoundedness(
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const;
// Returns the elapsed time in seconds due to compute only.
float GetInstructionElapsedDueToCompute(
const HloInstruction& instruction) const;
@ -127,6 +142,10 @@ class MemorySpaceAssignmentCostAnalysis {
int64 GetScheduleEndTime() const;
// Returns the number of nested while loop levels this instruction resides in.
// 0 means it is not in a while loop.
int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;
const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
private:
@ -134,6 +153,7 @@ class MemorySpaceAssignmentCostAnalysis {
float async_copy_bandwidth_bytes_per_second_;
float alternate_mem_bandwidth_bytes_per_second_;
const HloLiveRange& hlo_live_range_;
const CallGraph& call_graph_;
};
// Abstract base class that memory space assignment uses to pick prefetch
@ -262,10 +282,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
// corresponds to the instruction schedule.
float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;
// For performance reasons, we calculate the prefix sum of the elapsed time so
// that it's efficient to find the elapsed time in seconds in any logical
// interval.
std::vector<float> elapsed_time_cumsum_;
// For each instruction in the flattened schedule, maintain their elapsed time
// and while nesting level.
std::vector<float> elapsed_time_;
std::vector<int> while_nest_level_;
const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
float min_async_copy_to_overlap_ratio_;
@ -988,7 +1008,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
required_assignments_;
// Number of bytes reserved in alternate memory space.
int64 reserved_in_bytes_ = 0;
int64 global_max_time_;
};
} // namespace xla

View File

@ -57,9 +57,10 @@ class MemorySpaceAssignmentTest : public HloTestBase,
HloLiveRange::Run(module->schedule(), *alias_analysis,
module->entry_computation())
.ValueOrDie();
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
MemorySpaceAssignmentCostAnalysis cost_analysis(
hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth,
*hlo_live_range);
*hlo_live_range, *call_graph);
CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
CostAnalysisPrefetchIntervalPicker(
cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,