[XLA] Improve compile time performance of memory space assignment.

PiperOrigin-RevId: 314580604
Change-Id: I08ce28701eee0513bfea067662df3e2ce57d9987
This commit is contained in:
Berkin Ilbeyi 2020-06-03 12:21:52 -07:00 committed by TensorFlower Gardener
parent 929398ef01
commit 50c3196789
3 changed files with 119 additions and 53 deletions

View File

@ -31,18 +31,31 @@ const int kWhileExecutionCount = 5;
} // namespace
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
const HloInstruction& instruction,
float elapsed_time_due_to_alternate_mem) const {
const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
float elapsed_time_due_to_compute =
GetInstructionElapsedDueToCompute(instruction);
float elapsed_time_due_to_memory =
GetInstructionElapsedDueToMemory(instruction);
if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
// Memory bound, return how much alternate memory is better.
int while_nest_level = CalculateWhileLoopNestLevel(&instruction);
float while_nest_multiplier;
if (cache) {
// If there is a cache provided, memoize the while nest multiplier.
auto it = cache->while_nest_multiplier.find(&instruction);
if (it != cache->while_nest_multiplier.end()) {
while_nest_multiplier = it->second;
} else {
while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction));
cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
}
} else {
while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction));
}
return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
while_nest_level);
while_nest_multiplier;
} else {
// Compute bound, return how far off are we to memory boundedness.
return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
@ -50,18 +63,21 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
}
float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
const HloInstruction& defining_instruction =
*interval.buffer->defining_instruction();
float alternate_mem_benefit = GetAlternateMemoryBenefit(
defining_instruction,
GetInstructionElapsedDueToMemory(defining_instruction,
/*operand_in_alternate_mem=*/{},
/*output_in_alternate_mem=*/true));
/*output_in_alternate_mem=*/true),
cache);
for (const HloUse& use : interval.buffer->uses()) {
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
*use.instruction,
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number));
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number),
cache);
// If the benefit is positive (memory bound), add it to this buffer's
// benefit. If the benefit is negative (compute bound), calculate the
// maximum.
@ -221,9 +237,7 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
float min_async_copy_to_overlap_ratio,
float max_async_copy_to_overlap_ratio)
: elapsed_time_(
cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0),
while_nest_level_(
: while_nest_level_(
cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
cost_analysis_(cost_analysis),
min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
@ -232,19 +246,46 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
&cost_analysis_.hlo_live_range().instruction_schedule();
// Create a vector of elapsed times and while nesting levels of HLO
// instructions.
// instructions. The elapsed times are multiplied by pow(kWhileExecutionCount,
// nest_level) to account for executing the HLOs multiple times in while
// loops.
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
0.0);
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
*instruction_and_logical_time.first);
int64 logical_time = instruction_and_logical_time.second;
if (logical_time >= elapsed_time_.size()) {
elapsed_time_.resize(logical_time + 1, 0.0);
if (logical_time >= instructions_elapsed_time.size()) {
instructions_elapsed_time.resize(logical_time + 1, 0.0);
while_nest_level_.resize(logical_time + 1, 0);
}
elapsed_time_[logical_time] = elapsed_time;
while_nest_level_[logical_time] =
cost_analysis_.CalculateWhileLoopNestLevel(
instruction_and_logical_time.first);
int nest_level = cost_analysis_.CalculateWhileLoopNestLevel(
instruction_and_logical_time.first);
while_nest_level_[logical_time] = nest_level;
instructions_elapsed_time[logical_time] =
elapsed_time *
tensorflow::MathUtil::IPow<float>(kWhileExecutionCount, nest_level);
}
// As an optimization, create a cumulative sum vector of elapsed time.
float cumsum = 0.0;
elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
for (float elapsed_time : instructions_elapsed_time) {
cumsum += elapsed_time;
elapsed_time_cumsum_.push_back(cumsum);
}
// To be able to accurately determine the minimum nest level between a start
// time and an end time efficiently, populate a data structure that stores the
// closest nest level change index.
int prev_nest_level = 0;
int change_idx = -1;
while_nest_level_change_.reserve(instructions_elapsed_time.size());
for (int i = 0; i < while_nest_level_.size(); ++i) {
int nest_level = while_nest_level_[i];
if (nest_level != prev_nest_level) {
prev_nest_level = nest_level;
change_idx = i - 1;
}
while_nest_level_change_.push_back(change_idx);
}
}
@ -316,19 +357,32 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const {
logical_interval_elapsed + inst_elapsed_reduction_;
}
int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
int64 start_time, int64 end_time) const {
int min_nest_level =
std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
int change_idx = while_nest_level_change_[end_time];
while (change_idx >= start_time) {
min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
change_idx = while_nest_level_change_[change_idx];
}
return min_nest_level;
}
float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
int64 start_time, int64 end_time) const {
int interval_nest_level =
std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
float total_elapsed = 0;
for (int i = start_time + 1; i < end_time; ++i) {
total_elapsed +=
elapsed_time_[i] *
tensorflow::MathUtil::IPow<float>(
kWhileExecutionCount,
std::max(0, while_nest_level_[i] - interval_nest_level));
CHECK_LE(start_time, end_time);
if (start_time == end_time) {
return 0.0;
}
return total_elapsed;
// Since elapsed_time_cumsum_ is already weighed by the while loop nesting
// level, normalize the elapsed time by dividing with the nesting factor of
// the interval (start and end times).
int interval_nest_level = GetMinWhileNestLevel(start_time, end_time);
return (elapsed_time_cumsum_[end_time - 1] -
elapsed_time_cumsum_[start_time]) /
tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
interval_nest_level);
}
std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
@ -557,11 +611,12 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
int64 root_time = instruction_schedule.at(while_body->root_instruction());
int64 min_use_time = root_time;
for (const HloUse& parameter_use : parameter_value->uses()) {
int64 use_time = instruction_schedule.at(parameter_use.instruction);
if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
parameter_use.instruction->opcode() != HloOpcode::kTuple &&
parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
min_use_time = std::min(
min_use_time, instruction_schedule.at(parameter_use.instruction));
parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
use_time > parameter_time) {
min_use_time = std::min(min_use_time, use_time);
}
}
// If there is no use of this buffer inside the while loop, there is no need
@ -571,21 +626,13 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
<< "use time = " << min_use_time << ", root time = " << root_time;
return false;
}
HloValue* root_value =
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
while_body->root_instruction(), use.operand_index);
int64 root_definition_time =
instruction_schedule.at(root_value->defining_instruction());
const Shape& shape = root_value->shape();
const Shape& shape = parameter_value->shape();
// Allow the buffer in alternate memory if the buffer has a short live range
// either at the beginning or end of the while loop body.
if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
shape, parameter_time, min_use_time) &&
!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
shape, root_definition_time, root_time)) {
shape, parameter_time, min_use_time)) {
VLOG(4) << "While allocation not allowed in alternate memory. "
<< "use time = " << min_use_time
<< ", def time = " << root_definition_time
<< ", root time = " << root_time;
return false;
}
@ -1890,10 +1937,12 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
/*static*/ MemorySpaceAssignment::BufferIntervalCompare
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
return [&](const BufferInterval& x, const BufferInterval& y) {
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x);
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y);
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
MemorySpaceAssignmentCostAnalysis::Cache* cache) {
return [cost_analysis, cache](const BufferInterval& x,
const BufferInterval& y) {
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
if (x_memory_boundedness != y_memory_boundedness) {
return x_memory_boundedness > y_memory_boundedness;
}

View File

@ -78,6 +78,12 @@ class PresetAssignments {
// bandwidths of different memory spaces.
class MemorySpaceAssignmentCostAnalysis {
public:
// An optional Cache object may be provided to some of the methods below to
// speed up the lookup.
struct Cache {
absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
};
MemorySpaceAssignmentCostAnalysis(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
@ -97,15 +103,16 @@ class MemorySpaceAssignmentCostAnalysis {
// alternate memory would help if the op is memory bound, or otherwise how far
// off is the op to memory boundedness. The larger this number, the higher
// priority it will be placed in the alternate memory.
float GetAlternateMemoryBenefit(
const HloInstruction& instruction,
float elapsed_time_due_to_alternate_mem) const;
float GetAlternateMemoryBenefit(const HloInstruction& instruction,
float elapsed_time_due_to_alternate_mem,
Cache* cache = nullptr) const;
// Returns a heuristic value of memory boundedness for the given
// BufferInterval. The larger this number, the higher priority it will be
// placed in the alternate memory.
float GetMemoryBoundedness(
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const;
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
Cache* cache = nullptr) const;
// Returns the elapsed time in seconds due to compute only.
float GetInstructionElapsedDueToCompute(
@ -282,10 +289,17 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
// corresponds to the instruction schedule.
float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;
// Finds the minimum nest level in the given interval.
int GetMinWhileNestLevel(int64 start_time, int64 end_time) const;
// For each instruction in the flattened schedule, maintain their elapsed time
// and while nesting level.
std::vector<float> elapsed_time_;
// (in cumulative sum) and while nesting level.
std::vector<float> elapsed_time_cumsum_;
std::vector<int> while_nest_level_;
// Maintain the index of the most recent (before this instruction) nest level
// change in order to efficiently determine the minimum nest level in an
// interval.
std::vector<int> while_nest_level_change_;
const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
float min_async_copy_to_overlap_ratio_;
@ -645,7 +659,8 @@ class MemorySpaceAssignment {
StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;
static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
const MemorySpaceAssignmentCostAnalysis& cost_analysis);
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);
// Verify that the memory space assignment is free of overlapping buffers and
// export heap simulator trace to be used by buffer_assignment.

View File

@ -68,7 +68,7 @@ class MemorySpaceAssignmentTest : public HloTestBase,
return AssignMemorySpace(
module, /*max_outstanding_async_copies=*/-1,
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
cost_analysis),
cost_analysis, &cache_),
&prefetch_interval_picker);
}
@ -285,6 +285,8 @@ class MemorySpaceAssignmentTest : public HloTestBase,
TF_CHECK_OK(module->set_schedule(schedule));
return module;
}
MemorySpaceAssignmentCostAnalysis::Cache cache_;
};
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {