[XLA] Improve compile time performance of memory space assignment.

PiperOrigin-RevId: 314580604
Change-Id: I08ce28701eee0513bfea067662df3e2ce57d9987
Author: Berkin Ilbeyi
Date:   2020-06-03 12:21:52 -07:00
Committed by: TensorFlower Gardener
parent 929398ef01
commit 50c3196789
3 changed files with 119 additions and 53 deletions

tensorflow/compiler/xla/service/memory_space_assignment.cc

@@ -31,18 +31,31 @@ const int kWhileExecutionCount = 5;
 }  // namespace
 
 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
-    const HloInstruction& instruction,
-    float elapsed_time_due_to_alternate_mem) const {
+    const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
+    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
   float elapsed_time_due_to_compute =
       GetInstructionElapsedDueToCompute(instruction);
   float elapsed_time_due_to_memory =
       GetInstructionElapsedDueToMemory(instruction);
   if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
     // Memory bound, return how much alternate memory is better.
-    int while_nest_level = CalculateWhileLoopNestLevel(&instruction);
+    float while_nest_multiplier;
+    if (cache) {
+      // If there is a cache provided, memoize the while nest multiplier.
+      auto it = cache->while_nest_multiplier.find(&instruction);
+      if (it != cache->while_nest_multiplier.end()) {
+        while_nest_multiplier = it->second;
+      } else {
+        while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
+            kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction));
+        cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
+      }
+    } else {
+      while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
+          kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction));
+    }
     return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
-           tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
-                                             while_nest_level);
+           while_nest_multiplier;
   } else {
     // Compute bound, return how far off are we to memory boundedness.
     return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
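Note on the hunk above: this is plain memoization. The `IPow(kWhileExecutionCount, nest_level)` multiplier, together with the `CalculateWhileLoopNestLevel` walk, was recomputed for the same instruction on every comparison during buffer sorting; it is now keyed by instruction pointer in an optional, caller-owned cache. A minimal standalone sketch of the same pattern (the `Instruction`, `NestLevel`, `Cache`, and `WhileNestMultiplier` names are illustrative stand-ins, not the XLA types):

#include <cmath>
#include <unordered_map>

constexpr float kWhileExecutionCount = 5.0f;

struct Instruction { int nest_level; };  // stand-in for HloInstruction

// Stand-in for CalculateWhileLoopNestLevel; assumed expensive to compute.
int NestLevel(const Instruction* inst) { return inst->nest_level; }

struct Cache {
  std::unordered_map<const Instruction*, float> while_nest_multiplier;
};

// Returns pow(kWhileExecutionCount, nest_level), memoized per instruction
// when a cache is supplied, recomputed from scratch otherwise.
float WhileNestMultiplier(const Instruction* inst, Cache* cache) {
  if (cache) {
    auto it = cache->while_nest_multiplier.find(inst);
    if (it != cache->while_nest_multiplier.end()) return it->second;
    float m = std::pow(kWhileExecutionCount, NestLevel(inst));
    cache->while_nest_multiplier[inst] = m;
    return m;
  }
  return std::pow(kWhileExecutionCount, NestLevel(inst));
}

The cache parameter defaults to null so existing call sites keep working; the comparator factory further down owns one cache per sort, which is where the repeated lookups actually occur.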
@@ -50,18 +63,21 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
 }
 
 float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
+    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
+    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
   const HloInstruction& defining_instruction =
       *interval.buffer->defining_instruction();
   float alternate_mem_benefit = GetAlternateMemoryBenefit(
       defining_instruction,
       GetInstructionElapsedDueToMemory(defining_instruction,
                                        /*operand_in_alternate_mem=*/{},
-                                       /*output_in_alternate_mem=*/true));
+                                       /*output_in_alternate_mem=*/true),
+      cache);
   for (const HloUse& use : interval.buffer->uses()) {
     float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
         *use.instruction,
-        GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number));
+        GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number),
+        cache);
     // If the benefit is positive (memory bound), add it to this buffer's
     // benefit. If the benefit is negative (compute bound), calculate the
     // maximum.
@@ -221,9 +237,7 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
     float min_async_copy_to_overlap_ratio,
     float max_async_copy_to_overlap_ratio)
-    : elapsed_time_(
-          cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0),
-      while_nest_level_(
+    : while_nest_level_(
           cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
       cost_analysis_(cost_analysis),
       min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
@@ -232,19 +246,46 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
   instruction_schedule_ =
       &cost_analysis_.hlo_live_range().instruction_schedule();
 
   // Create a vector of elapsed times and while nesting levels of HLO
-  // instructions.
+  // instructions. The elapsed times are multiplied by pow(kWhileExecutionCount,
+  // nest_level) to account for executing the HLOs multiple times in while
+  // loops.
+  std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
+                                               0.0);
   for (const auto& instruction_and_logical_time : *instruction_schedule_) {
     float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
         *instruction_and_logical_time.first);
     int64 logical_time = instruction_and_logical_time.second;
-    if (logical_time >= elapsed_time_.size()) {
-      elapsed_time_.resize(logical_time + 1, 0.0);
+    if (logical_time >= instructions_elapsed_time.size()) {
+      instructions_elapsed_time.resize(logical_time + 1, 0.0);
       while_nest_level_.resize(logical_time + 1, 0);
     }
-    elapsed_time_[logical_time] = elapsed_time;
-    while_nest_level_[logical_time] =
-        cost_analysis_.CalculateWhileLoopNestLevel(
-            instruction_and_logical_time.first);
+    int nest_level = cost_analysis_.CalculateWhileLoopNestLevel(
+        instruction_and_logical_time.first);
+    while_nest_level_[logical_time] = nest_level;
+    instructions_elapsed_time[logical_time] =
+        elapsed_time *
+        tensorflow::MathUtil::IPow<float>(kWhileExecutionCount, nest_level);
   }
+
+  // As an optimization, create a cumulative sum vector of elapsed time.
+  float cumsum = 0.0;
+  elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
+  for (float elapsed_time : instructions_elapsed_time) {
+    cumsum += elapsed_time;
+    elapsed_time_cumsum_.push_back(cumsum);
+  }
+  // To be able to accurately determine the minimum nest level between a start
+  // time and an end time efficiently, populate a data structure that stores
+  // the closest nest level change index.
+  int prev_nest_level = 0;
+  int change_idx = -1;
+  while_nest_level_change_.reserve(instructions_elapsed_time.size());
+  for (int i = 0; i < while_nest_level_.size(); ++i) {
+    int nest_level = while_nest_level_[i];
+    if (nest_level != prev_nest_level) {
+      prev_nest_level = nest_level;
+      change_idx = i - 1;
+    }
+    while_nest_level_change_.push_back(change_idx);
+  }
 }
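The central optimization in this constructor is the switch from storing per-instruction elapsed times to storing their cumulative sum: an interval's nest-level-weighted elapsed time then costs two array reads instead of a loop over the schedule, which matters because GetLogicalIntervalElapsed is called repeatedly inside the prefetch search. A self-contained sketch of the prefix-sum idea, using the same exclusive-interval convention as GetLogicalIntervalElapsed (the `CumulativeSum` and `IntervalSum` names are illustrative):

#include <cstdint>
#include <vector>

// cumsum[i] = v[0] + v[1] + ... + v[i].
std::vector<float> CumulativeSum(const std::vector<float>& v) {
  std::vector<float> cumsum;
  cumsum.reserve(v.size());
  float running = 0.0f;
  for (float x : v) {
    running += x;
    cumsum.push_back(running);
  }
  return cumsum;
}

// Sum of v[start + 1 .. end - 1] in O(1); the interval endpoints themselves
// are excluded, matching the convention in GetLogicalIntervalElapsed below.
float IntervalSum(const std::vector<float>& cumsum, int64_t start,
                  int64_t end) {
  if (start == end) return 0.0f;
  return cumsum[end - 1] - cumsum[start];
}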
@@ -316,19 +357,32 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const {
          logical_interval_elapsed + inst_elapsed_reduction_;
 }
 
+int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
+    int64 start_time, int64 end_time) const {
+  int min_nest_level =
+      std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
+  int change_idx = while_nest_level_change_[end_time];
+  while (change_idx >= start_time) {
+    min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
+    change_idx = while_nest_level_change_[change_idx];
+  }
+  return min_nest_level;
+}
+
 float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
     int64 start_time, int64 end_time) const {
-  int interval_nest_level =
-      std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
-  float total_elapsed = 0;
-  for (int i = start_time + 1; i < end_time; ++i) {
-    total_elapsed +=
-        elapsed_time_[i] *
-        tensorflow::MathUtil::IPow<float>(
-            kWhileExecutionCount,
-            std::max(0, while_nest_level_[i] - interval_nest_level));
+  CHECK_LE(start_time, end_time);
+  if (start_time == end_time) {
+    return 0.0;
   }
-  return total_elapsed;
+  // Since elapsed_time_cumsum_ is already weighed by the while loop nesting
+  // level, normalize the elapsed time by dividing with the nesting factor of
+  // the interval (start and end times).
+  int interval_nest_level = GetMinWhileNestLevel(start_time, end_time);
+  return (elapsed_time_cumsum_[end_time - 1] -
+          elapsed_time_cumsum_[start_time]) /
+         tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
+                                           interval_nest_level);
 }
 
 std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
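GetMinWhileNestLevel avoids scanning every instruction in the interval: while_nest_level_change_[i] records the last index of the previous nest-level run at or before i, so the loop hops backwards only across run boundaries. That is sufficient because the minimum over the interval is attained either at an endpoint or at one of those boundaries. A standalone sketch of both the index construction and the query (illustrative `BuildChangeIndex` and `MinNestLevel` names, same conventions as the code above):

#include <algorithm>
#include <cstdint>
#include <vector>

// change[i] = index just before the most recent nest-level change at or
// before i, or -1 if the level has not changed yet. Mirrors the
// while_nest_level_change_ construction in the constructor above.
std::vector<int> BuildChangeIndex(const std::vector<int>& nest_level) {
  std::vector<int> change;
  change.reserve(nest_level.size());
  int prev_level = 0;
  int change_idx = -1;
  for (int i = 0; i < static_cast<int>(nest_level.size()); ++i) {
    if (nest_level[i] != prev_level) {
      prev_level = nest_level[i];
      change_idx = i - 1;
    }
    change.push_back(change_idx);
  }
  return change;
}

// Minimum nest level over [start, end], visiting only run boundaries.
int MinNestLevel(const std::vector<int>& nest_level,
                 const std::vector<int>& change, int64_t start, int64_t end) {
  int min_level = std::min(nest_level[start], nest_level[end]);
  for (int64_t idx = change[end]; idx >= start; idx = change[idx]) {
    min_level = std::min(min_level, nest_level[idx]);
  }
  return min_level;
}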
@@ -557,11 +611,12 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
     int64 root_time = instruction_schedule.at(while_body->root_instruction());
     int64 min_use_time = root_time;
     for (const HloUse& parameter_use : parameter_value->uses()) {
+      int64 use_time = instruction_schedule.at(parameter_use.instruction);
       if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
           parameter_use.instruction->opcode() != HloOpcode::kTuple &&
-          parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
-        min_use_time = std::min(
-            min_use_time, instruction_schedule.at(parameter_use.instruction));
+          parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
+          use_time > parameter_time) {
+        min_use_time = std::min(min_use_time, use_time);
       }
     }
     // If there is no use of this buffer inside the while loop, there is no need
@@ -571,21 +626,13 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
             << "use time = " << min_use_time << ", root time = " << root_time;
     return false;
   }
-  HloValue* root_value =
-      &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
-          while_body->root_instruction(), use.operand_index);
-  int64 root_definition_time =
-      instruction_schedule.at(root_value->defining_instruction());
-  const Shape& shape = root_value->shape();
+  const Shape& shape = parameter_value->shape();
   // Allow the buffer in alternate memory if the buffer has a short live range
   // either at the beginning or end of the while loop body.
   if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
-          shape, parameter_time, min_use_time) &&
-      !options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
-          shape, root_definition_time, root_time)) {
+          shape, parameter_time, min_use_time)) {
     VLOG(4) << "While allocation not allowed in alternate memory. "
             << "use time = " << min_use_time
-            << ", def time = " << root_definition_time
             << ", root time = " << root_time;
     return false;
   }
@@ -1890,10 +1937,12 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
 
 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
-    const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
-  return [&](const BufferInterval& x, const BufferInterval& y) {
-    float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x);
-    float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y);
+    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
+    MemorySpaceAssignmentCostAnalysis::Cache* cache) {
+  return [cost_analysis, cache](const BufferInterval& x,
+                                const BufferInterval& y) {
+    float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
+    float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
     if (x_memory_boundedness != y_memory_boundedness) {
       return x_memory_boundedness > y_memory_boundedness;
     }
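Threading the cache through the comparator factory is what makes the memoization pay off: a single sort invokes the comparator O(n log n) times, and each call previously recomputed GetMemoryBoundedness from scratch. The lambda also now captures its state by value instead of [&], so the returned comparator does not hold dangling references to the factory's arguments; the cache pointer itself is owned by the caller (the test fixture below keeps it as a member). A sketch of the shape of this pattern, with hypothetical Item/Score/MakeCompare stand-ins:

#include <algorithm>
#include <functional>
#include <unordered_map>
#include <vector>

struct Item { int id; };

// Assumed-expensive score, memoized in a caller-owned cache so that one
// sort never computes the same item's score twice.
float Score(const Item& item, std::unordered_map<int, float>* cache) {
  auto it = cache->find(item.id);
  if (it != cache->end()) return it->second;
  float score = static_cast<float>(item.id % 7);  // placeholder cost model
  (*cache)[item.id] = score;
  return score;
}

// Comparator factory: captures the cache pointer by value; the caller must
// keep the cache alive while the comparator is in use.
std::function<bool(const Item&, const Item&)> MakeCompare(
    std::unordered_map<int, float>* cache) {
  return [cache](const Item& x, const Item& y) {
    return Score(x, cache) > Score(y, cache);
  };
}

// Usage:
//   std::unordered_map<int, float> cache;
//   std::sort(items.begin(), items.end(), MakeCompare(&cache));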

tensorflow/compiler/xla/service/memory_space_assignment.h

@@ -78,6 +78,12 @@ class PresetAssignments {
 // bandwidths of different memory spaces.
 class MemorySpaceAssignmentCostAnalysis {
  public:
+  // An optional Cache object may be provided to some of the methods below to
+  // speed up the lookup.
+  struct Cache {
+    absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
+  };
+
   MemorySpaceAssignmentCostAnalysis(
       const HloCostAnalysis& cost_analysis,
       float async_copy_bandwidth_bytes_per_second,
@@ -97,15 +103,16 @@ class MemorySpaceAssignmentCostAnalysis {
   // alternate memory would help if the op is memory bound, or otherwise how far
   // off is the op to memory boundedness. The larger this number, the higher
   // priority it will be placed in the alternate memory.
-  float GetAlternateMemoryBenefit(
-      const HloInstruction& instruction,
-      float elapsed_time_due_to_alternate_mem) const;
+  float GetAlternateMemoryBenefit(const HloInstruction& instruction,
+                                  float elapsed_time_due_to_alternate_mem,
+                                  Cache* cache = nullptr) const;
 
   // Returns a heuristic value of memory boundedness for the given
   // BufferInterval. The larger this number, the higher priority it will be
   // placed in the alternate memory.
   float GetMemoryBoundedness(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const;
+      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
+      Cache* cache = nullptr) const;
 
   // Returns the elapsed time in seconds due to compute only.
   float GetInstructionElapsedDueToCompute(
@@ -282,10 +289,17 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
   // corresponds to the instruction schedule.
   float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;
 
+  // Finds the minimum nest level in the given interval.
+  int GetMinWhileNestLevel(int64 start_time, int64 end_time) const;
+
   // For each instruction in the flattened schedule, maintain their elapsed time
-  // and while nesting level.
-  std::vector<float> elapsed_time_;
+  // (in cumulative sum) and while nesting level.
+  std::vector<float> elapsed_time_cumsum_;
   std::vector<int> while_nest_level_;
+  // Maintain the index of the most recent (before this instruction) nest level
+  // change in order to efficiently determine the minimum nest level in an
+  // interval.
+  std::vector<int> while_nest_level_change_;
 
   const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
   float min_async_copy_to_overlap_ratio_;
@@ -645,7 +659,8 @@ class MemorySpaceAssignment {
   StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;
 
   static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
-      const MemorySpaceAssignmentCostAnalysis& cost_analysis);
+      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
+      MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);
 
   // Verify that the memory space assignment is free of overlapping buffers and
   // export heap simulator trace to be used by buffer_assignment.

tensorflow/compiler/xla/service/memory_space_assignment_test.cc

@@ -68,7 +68,7 @@ class MemorySpaceAssignmentTest : public HloTestBase,
     return AssignMemorySpace(
         module, /*max_outstanding_async_copies=*/-1,
         MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
-            cost_analysis),
+            cost_analysis, &cache_),
         &prefetch_interval_picker);
   }
@@ -285,6 +285,8 @@ class MemorySpaceAssignmentTest : public HloTestBase,
     TF_CHECK_OK(module->set_schedule(schedule));
     return module;
   }
+
+  MemorySpaceAssignmentCostAnalysis::Cache cache_;
 };
 
 TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {