[XLA] Improve compile time performance of memory space assignment.
PiperOrigin-RevId: 314580604 Change-Id: I08ce28701eee0513bfea067662df3e2ce57d9987
This commit is contained in:
parent
929398ef01
commit
50c3196789
|
@ -31,18 +31,31 @@ const int kWhileExecutionCount = 5;
|
|||
} // namespace
|
||||
|
||||
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
|
||||
const HloInstruction& instruction,
|
||||
float elapsed_time_due_to_alternate_mem) const {
|
||||
const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
|
||||
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
|
||||
float elapsed_time_due_to_compute =
|
||||
GetInstructionElapsedDueToCompute(instruction);
|
||||
float elapsed_time_due_to_memory =
|
||||
GetInstructionElapsedDueToMemory(instruction);
|
||||
if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
|
||||
// Memory bound, return how much alternate memory is better.
|
||||
int while_nest_level = CalculateWhileLoopNestLevel(&instruction);
|
||||
float while_nest_multiplier;
|
||||
if (cache) {
|
||||
// If there is a cache provided, memoize the while nest multiplier.
|
||||
auto it = cache->while_nest_multiplier.find(&instruction);
|
||||
if (it != cache->while_nest_multiplier.end()) {
|
||||
while_nest_multiplier = it->second;
|
||||
} else {
|
||||
while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
|
||||
kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction));
|
||||
cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
|
||||
}
|
||||
} else {
|
||||
while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
|
||||
kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction));
|
||||
}
|
||||
return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
|
||||
tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
|
||||
while_nest_level);
|
||||
while_nest_multiplier;
|
||||
} else {
|
||||
// Compute bound, return how far off are we to memory boundedness.
|
||||
return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
|
||||
|
@ -50,18 +63,21 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
|
|||
}
|
||||
|
||||
float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
|
||||
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
|
||||
const HloInstruction& defining_instruction =
|
||||
*interval.buffer->defining_instruction();
|
||||
float alternate_mem_benefit = GetAlternateMemoryBenefit(
|
||||
defining_instruction,
|
||||
GetInstructionElapsedDueToMemory(defining_instruction,
|
||||
/*operand_in_alternate_mem=*/{},
|
||||
/*output_in_alternate_mem=*/true));
|
||||
/*output_in_alternate_mem=*/true),
|
||||
cache);
|
||||
for (const HloUse& use : interval.buffer->uses()) {
|
||||
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
|
||||
*use.instruction,
|
||||
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number));
|
||||
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number),
|
||||
cache);
|
||||
// If the benefit is positive (memory bound), add it to this buffer's
|
||||
// benefit. If the benefit is negative (compute bound), calculate the
|
||||
// maximum.
|
||||
|
@ -221,9 +237,7 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
|
|||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||
float min_async_copy_to_overlap_ratio,
|
||||
float max_async_copy_to_overlap_ratio)
|
||||
: elapsed_time_(
|
||||
cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0),
|
||||
while_nest_level_(
|
||||
: while_nest_level_(
|
||||
cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
|
||||
cost_analysis_(cost_analysis),
|
||||
min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
|
||||
|
@ -232,19 +246,46 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
|
|||
&cost_analysis_.hlo_live_range().instruction_schedule();
|
||||
|
||||
// Create a vector of elapsed times and while nesting levels of HLO
|
||||
// instructions.
|
||||
// instructions. The elapsed times are multiplied by pow(kWhileExecutionCount,
|
||||
// nest_level) to account for executing the HLOs multiple times in while
|
||||
// loops.
|
||||
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
|
||||
0.0);
|
||||
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
|
||||
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
|
||||
*instruction_and_logical_time.first);
|
||||
int64 logical_time = instruction_and_logical_time.second;
|
||||
if (logical_time >= elapsed_time_.size()) {
|
||||
elapsed_time_.resize(logical_time + 1, 0.0);
|
||||
if (logical_time >= instructions_elapsed_time.size()) {
|
||||
instructions_elapsed_time.resize(logical_time + 1, 0.0);
|
||||
while_nest_level_.resize(logical_time + 1, 0);
|
||||
}
|
||||
elapsed_time_[logical_time] = elapsed_time;
|
||||
while_nest_level_[logical_time] =
|
||||
cost_analysis_.CalculateWhileLoopNestLevel(
|
||||
int nest_level = cost_analysis_.CalculateWhileLoopNestLevel(
|
||||
instruction_and_logical_time.first);
|
||||
while_nest_level_[logical_time] = nest_level;
|
||||
instructions_elapsed_time[logical_time] =
|
||||
elapsed_time *
|
||||
tensorflow::MathUtil::IPow<float>(kWhileExecutionCount, nest_level);
|
||||
}
|
||||
// As an optimization, create a cumulative sum vector of elapsed time.
|
||||
float cumsum = 0.0;
|
||||
elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
|
||||
for (float elapsed_time : instructions_elapsed_time) {
|
||||
cumsum += elapsed_time;
|
||||
elapsed_time_cumsum_.push_back(cumsum);
|
||||
}
|
||||
// To be able to accurately determine the minimum nest level between a start
|
||||
// time and an end time efficiently, populate a data structure that stores the
|
||||
// closest nest level change index.
|
||||
int prev_nest_level = 0;
|
||||
int change_idx = -1;
|
||||
while_nest_level_change_.reserve(instructions_elapsed_time.size());
|
||||
for (int i = 0; i < while_nest_level_.size(); ++i) {
|
||||
int nest_level = while_nest_level_[i];
|
||||
if (nest_level != prev_nest_level) {
|
||||
prev_nest_level = nest_level;
|
||||
change_idx = i - 1;
|
||||
}
|
||||
while_nest_level_change_.push_back(change_idx);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -316,19 +357,32 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const {
|
|||
logical_interval_elapsed + inst_elapsed_reduction_;
|
||||
}
|
||||
|
||||
int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
|
||||
int64 start_time, int64 end_time) const {
|
||||
int min_nest_level =
|
||||
std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
|
||||
int change_idx = while_nest_level_change_[end_time];
|
||||
while (change_idx >= start_time) {
|
||||
min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
|
||||
change_idx = while_nest_level_change_[change_idx];
|
||||
}
|
||||
return min_nest_level;
|
||||
}
|
||||
|
||||
float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
|
||||
int64 start_time, int64 end_time) const {
|
||||
int interval_nest_level =
|
||||
std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
|
||||
float total_elapsed = 0;
|
||||
for (int i = start_time + 1; i < end_time; ++i) {
|
||||
total_elapsed +=
|
||||
elapsed_time_[i] *
|
||||
tensorflow::MathUtil::IPow<float>(
|
||||
kWhileExecutionCount,
|
||||
std::max(0, while_nest_level_[i] - interval_nest_level));
|
||||
CHECK_LE(start_time, end_time);
|
||||
if (start_time == end_time) {
|
||||
return 0.0;
|
||||
}
|
||||
return total_elapsed;
|
||||
// Since elapsed_time_cumsum_ is already weighed by the while loop nesting
|
||||
// level, normalize the elapsed time by dividing with the nesting factor of
|
||||
// the interval (start and end times).
|
||||
int interval_nest_level = GetMinWhileNestLevel(start_time, end_time);
|
||||
return (elapsed_time_cumsum_[end_time - 1] -
|
||||
elapsed_time_cumsum_[start_time]) /
|
||||
tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
|
||||
interval_nest_level);
|
||||
}
|
||||
|
||||
std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
|
||||
|
@ -557,11 +611,12 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
|
|||
int64 root_time = instruction_schedule.at(while_body->root_instruction());
|
||||
int64 min_use_time = root_time;
|
||||
for (const HloUse& parameter_use : parameter_value->uses()) {
|
||||
int64 use_time = instruction_schedule.at(parameter_use.instruction);
|
||||
if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
|
||||
parameter_use.instruction->opcode() != HloOpcode::kTuple &&
|
||||
parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
|
||||
min_use_time = std::min(
|
||||
min_use_time, instruction_schedule.at(parameter_use.instruction));
|
||||
parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
|
||||
use_time > parameter_time) {
|
||||
min_use_time = std::min(min_use_time, use_time);
|
||||
}
|
||||
}
|
||||
// If there is no use of this buffer inside the while loop, there is no need
|
||||
|
@ -571,21 +626,13 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
|
|||
<< "use time = " << min_use_time << ", root time = " << root_time;
|
||||
return false;
|
||||
}
|
||||
HloValue* root_value =
|
||||
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
|
||||
while_body->root_instruction(), use.operand_index);
|
||||
int64 root_definition_time =
|
||||
instruction_schedule.at(root_value->defining_instruction());
|
||||
const Shape& shape = root_value->shape();
|
||||
const Shape& shape = parameter_value->shape();
|
||||
// Allow the buffer in alternate memory if the buffer has a short live range
|
||||
// either at the beginning or end of the while loop body.
|
||||
if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
|
||||
shape, parameter_time, min_use_time) &&
|
||||
!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
|
||||
shape, root_definition_time, root_time)) {
|
||||
shape, parameter_time, min_use_time)) {
|
||||
VLOG(4) << "While allocation not allowed in alternate memory. "
|
||||
<< "use time = " << min_use_time
|
||||
<< ", def time = " << root_definition_time
|
||||
<< ", root time = " << root_time;
|
||||
return false;
|
||||
}
|
||||
|
@ -1890,10 +1937,12 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
|
|||
|
||||
/*static*/ MemorySpaceAssignment::BufferIntervalCompare
|
||||
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
|
||||
return [&](const BufferInterval& x, const BufferInterval& y) {
|
||||
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x);
|
||||
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y);
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||
MemorySpaceAssignmentCostAnalysis::Cache* cache) {
|
||||
return [cost_analysis, cache](const BufferInterval& x,
|
||||
const BufferInterval& y) {
|
||||
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
|
||||
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
|
||||
if (x_memory_boundedness != y_memory_boundedness) {
|
||||
return x_memory_boundedness > y_memory_boundedness;
|
||||
}
|
||||
|
|
|
@ -78,6 +78,12 @@ class PresetAssignments {
|
|||
// bandwidths of different memory spaces.
|
||||
class MemorySpaceAssignmentCostAnalysis {
|
||||
public:
|
||||
// An optional Cache object may be provided to some of the methods below to
|
||||
// speed up the lookup.
|
||||
struct Cache {
|
||||
absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
|
||||
};
|
||||
|
||||
MemorySpaceAssignmentCostAnalysis(
|
||||
const HloCostAnalysis& cost_analysis,
|
||||
float async_copy_bandwidth_bytes_per_second,
|
||||
|
@ -97,15 +103,16 @@ class MemorySpaceAssignmentCostAnalysis {
|
|||
// alternate memory would help if the op is memory bound, or otherwise how far
|
||||
// off is the op to memory boundedness. The larger this number, the higher
|
||||
// priority it will be placed in the alternate memory.
|
||||
float GetAlternateMemoryBenefit(
|
||||
const HloInstruction& instruction,
|
||||
float elapsed_time_due_to_alternate_mem) const;
|
||||
float GetAlternateMemoryBenefit(const HloInstruction& instruction,
|
||||
float elapsed_time_due_to_alternate_mem,
|
||||
Cache* cache = nullptr) const;
|
||||
|
||||
// Returns a heuristic value of memory boundedness for the given
|
||||
// BufferInterval. The larger this number, the higher priority it will be
|
||||
// placed in the alternate memory.
|
||||
float GetMemoryBoundedness(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const;
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
|
||||
Cache* cache = nullptr) const;
|
||||
|
||||
// Returns the elapsed time in seconds due to compute only.
|
||||
float GetInstructionElapsedDueToCompute(
|
||||
|
@ -282,10 +289,17 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
|
|||
// corresponds to the instruction schedule.
|
||||
float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;
|
||||
|
||||
// Finds the minimum nest level in the given interval.
|
||||
int GetMinWhileNestLevel(int64 start_time, int64 end_time) const;
|
||||
|
||||
// For each instruction in the flattened schedule, maintain their elapsed time
|
||||
// and while nesting level.
|
||||
std::vector<float> elapsed_time_;
|
||||
// (in cumulative sum) and while nesting level.
|
||||
std::vector<float> elapsed_time_cumsum_;
|
||||
std::vector<int> while_nest_level_;
|
||||
// Maintain the index of the most recent (before this instruction) nest level
|
||||
// change in order to efficiently determine the minimum nest level in an
|
||||
// interval.
|
||||
std::vector<int> while_nest_level_change_;
|
||||
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
|
||||
float min_async_copy_to_overlap_ratio_;
|
||||
|
@ -645,7 +659,8 @@ class MemorySpaceAssignment {
|
|||
StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;
|
||||
|
||||
static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis);
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||
MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);
|
||||
|
||||
// Verify that the memory space assignment is free of overlapping buffers and
|
||||
// export heap simulator trace to be used by buffer_assignment.
|
||||
|
|
|
@ -68,7 +68,7 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
|||
return AssignMemorySpace(
|
||||
module, /*max_outstanding_async_copies=*/-1,
|
||||
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
||||
cost_analysis),
|
||||
cost_analysis, &cache_),
|
||||
&prefetch_interval_picker);
|
||||
}
|
||||
|
||||
|
@ -285,6 +285,8 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
|||
TF_CHECK_OK(module->set_schedule(schedule));
|
||||
return module;
|
||||
}
|
||||
|
||||
MemorySpaceAssignmentCostAnalysis::Cache cache_;
|
||||
};
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
|
||||
|
|
Loading…
Reference in New Issue