[XLA] NFC: Refactor to split a very large method.
PiperOrigin-RevId: 314738045
Change-Id: I07ac601281ff03f0209727fef446568d1df576b2
commit 47f7695c00 (parent 3897e02fc9)
@@ -787,14 +787,12 @@ void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
   }
 }
 
-void AlternateMemoryBestFitHeap::DumpIfEnabled(
-    absl::string_view buffer_info_str,
-    absl::string_view allocation_info_str) const {
+void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
   if (!options_.dump_fn) {
     return;
   }
-  options_.dump_fn("bufferinfo", buffer_info_str);
-  options_.dump_fn("allocinfo", allocation_info_str);
+  options_.dump_fn("bufferinfo", buffer_info_str_);
+  options_.dump_fn("allocinfo", allocation_info_str_);
 }
 
 HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
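A note on the hunk above: the dump hook keeps its two-argument shape (a category tag and a payload), but the debug text now lives in the buffer_info_str_ and allocation_info_str_ members rather than being threaded through parameters. Below is a minimal standalone sketch of that pattern with a hypothetical DebugDumper class; only the hook shape and the guard are taken from the diff itself.

#include <functional>
#include <iostream>
#include <string>

// Hypothetical sketch of the pattern the hunk above moves to: debug text
// accumulates in members, and an optional two-argument hook (category tag,
// payload) emits it, mirroring options_.dump_fn("bufferinfo", ...).
class DebugDumper {
 public:
  std::function<void(const std::string&, const std::string&)> dump_fn;

  void AppendBufferInfo(const std::string& line) { buffer_info_str_ += line; }
  void AppendAllocationInfo(const std::string& line) {
    allocation_info_str_ += line;
  }

  // Same guard as in the diff: dumping is optional and skipped when no hook
  // is installed.
  void DumpDebugStringsIfEnabled() const {
    if (!dump_fn) {
      return;
    }
    dump_fn("bufferinfo", buffer_info_str_);
    dump_fn("allocinfo", allocation_info_str_);
  }

 private:
  std::string buffer_info_str_;
  std::string allocation_info_str_;
};

int main() {
  DebugDumper dumper;
  dumper.dump_fn = [](const std::string& category, const std::string& text) {
    std::cout << category << ":\n" << text;
  };
  dumper.AppendBufferInfo("buffer A: 128 bytes\n");
  dumper.AppendAllocationInfo("buffer A -> alternate memory\n");
  dumper.DumpDebugStringsIfEnabled();  // prints both accumulated strings
  return 0;
}

The payoff shows up later in the diff: once Finish() is split, both it and the new AllocateColocatedIntervals can append to the same members without passing strings around.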
@@ -816,9 +814,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
     }
   }
 
-  std::string buffer_info_str;
-  std::string allocation_info_str;
-
   for (auto& interval : sorted_buffer_intervals) {
     if (!interval.need_allocation) {
       continue;
@@ -842,12 +837,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
     }
 
     auto colocated_intervals = GetSortedColocatedIntervals(interval);
-    // Create AllocationValues for all the
-    // colocated intervals.
-    std::vector<AllocationValue> allocation_values;
-    for (const auto& colocated_interval : colocated_intervals) {
-      CreateAllocationValues(colocated_interval->buffer, &allocation_values);
-    }
 
     if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
       VLOG(3) << "Interval " << interval.buffer->ToShortString()
@@ -890,8 +879,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
       continue;
     }
 
-    const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
-
     // TODO(berkin): For now, place the phi values due to conditionals in
     // default memory.
     for (const BufferInterval* colocated_interval : colocated_intervals) {
@@ -911,192 +898,203 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
       }
     }
 
-    AppendBufferInfoDebugString(interval, &buffer_info_str);
-
-    // Data structure to contain the preferred offset for a given computation.
-    // We ensure that the same offset will be allocated outside the while loop
-    // as well as inside the while loop.
-    absl::flat_hash_map<const HloComputation*, int64>
-        preferred_offset_for_computation;
-    bool allocation_success = true;
-    for (auto& allocation_value : allocation_values) {
-      int64 definition_time =
-          instruction_schedule.at(allocation_value.defining_instruction());
-
-      absl::optional<int64> preferred_offset;
-      auto preferred_offset_it =
-          preferred_offset_for_computation.find(allocation_value.computation());
-      if (preferred_offset_it != preferred_offset_for_computation.end()) {
-        preferred_offset = preferred_offset_it->second;
-      }
-
-      // Iterate over the uses.
-      for (int use_idx = 0; use_idx < allocation_value.uses().size();
-           ++use_idx) {
-        const HloUse& use = allocation_value.uses().at(use_idx);
-        int64 use_time = instruction_schedule.at(use.instruction);
-        int64 latest_prefetch_time = use_time;
-        bool allow_no_copy_alternate_mem_allocation = true;
-        absl::optional<int64> earliest_prefetch_time = absl::nullopt;
-
-        // Sequential calls include kWhile, kCall, and kConditional opcodes.
-        bool is_sequential_call =
-            (GetInstructionCallContext(use.instruction->opcode()) ==
-             CallContext::kSequential);
-        if (is_sequential_call) {
-          for (const HloComputation* called_computation :
-               use.instruction->called_computations()) {
-            const HloLiveRange::TimeBound& computation_span =
-                hlo_live_range_.computation_span_times().at(called_computation);
-            latest_prefetch_time =
-                std::min(computation_span.start, latest_prefetch_time);
-          }
-          if (use.instruction->opcode() == HloOpcode::kWhile) {
-            // Given an example while loop and flattened schedule (logical times
-            // shown on the left):
-            //
-            // 0:  a = ...
-            // 1:  ...
-            //     cond {
-            // 2:   p = param(0)
-            // 3:   ...
-            //     }
-            //     body {
-            // 4:   p = param(0)
-            // 5:   ...
-            // 6:   ROOT ...
-            //     }
-            // 7:  w = while(a), body=body, cond=cond
-            //
-            // When processing "a" (time 0) and its while use (time 7), we
-            // update the interval to time 0-4. This is so that the remaining
-            // interval (5-6) can be allocated separately and this buffer
-            // doesn't waste alternate memory space within the while loop body.
-            HloComputation* while_body = use.instruction->while_body();
-            // We require while body ROOTs to be the last in the schedule.
-            CHECK_EQ(
-                instruction_schedule.at(while_body->root_instruction()) + 1,
-                instruction_schedule.at(use.instruction))
-                << "While body ROOTs need to be the last in the schedule! "
-                   "Please run RootInstructionSinker.";
-            // Replace the use time with the parameter time so that we can
-            // decide on alternate memory allocations within the while loop body
-            // when we look at uses within the while loop body.
-            use_time =
-                instruction_schedule.at(while_body->parameter_instruction(0));
-          } else if (use.instruction->opcode() == HloOpcode::kConditional) {
-            // Replace the use time with the earliest parameter of called
-            // computations.
-            for (const HloComputation* called_computation :
-                 use.instruction->called_computations()) {
-              use_time = std::min(
-                  use_time, instruction_schedule.at(
-                                called_computation->parameter_instruction(0)));
-            }
-          }
-        }
-
-        // Add a required assignment in default memory if the use not allowed in
-        // alternate memory.
-        if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
-          AddRequiredAssignment(allocation_value.value(), use.instruction,
-                                MemorySpace::kDefault, use_time);
-        } else if (use_idx > 0) {
-          // We allow buffers in alternate memory that are passed into
-          // conditionals to give up their alternate memory allocation inside
-          // the called computation. This means that if a conditional operator
-          // has an alternate memory allocation, subsequent uses cannot use the
-          // same alternate memory allocation in order not to clobber data. So
-          // we force default memory allocation for these subsequent uses.
-          const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
-          if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
-              previous_use.instruction != use.instruction) {
-            allow_no_copy_alternate_mem_allocation = false;
-            earliest_prefetch_time =
-                instruction_schedule.at(previous_use.instruction);
-            VLOG(3) << "Previous use (" << previous_use.ToString()
-                    << ") of use (" << use.ToString()
-                    << ") is a conditional, so this use will need to evict. "
-                    << "Earliest prefetch time = " << *earliest_prefetch_time;
-          }
-        }
-
-        // Bitcasts don't define buffers and don't directly consume buffers.
-        // Skip allocating buffers for bitcast uses. The uses that feed from
-        // bitcasts will be handled specially.
-        if (use.instruction->opcode() != HloOpcode::kBitcast) {
-          AllocationRequest request;
-          // Rarely, (e.g., when conditional true and false parameters are the
-          // same), definition time can be the time of the conditional and use
-          // time is the parameter use, which is less.
-          request.start_time = std::min(definition_time, use_time);
-          request.end_time = use_time;
-          request.latest_prefetch_time = latest_prefetch_time;
-          request.size = interval.size;
-          request.allow_no_copy_alternate_mem_allocation =
-              allow_no_copy_alternate_mem_allocation;
-          request.earliest_prefetch_time = earliest_prefetch_time;
-          request.preferred_offset = preferred_offset;
-          request.use = use;
-          request.allocation_value = &allocation_value;
-          if (!FindAllocation(request)) {
-            // If the allocation finding failed (e.g., due to running out of
-            // asynchronous copies), then fall back to allocating the buffer
-            // entirely in the default memory.
-            UncommitPendingChunks();
-            allocation_success = false;
-            break;
-          }
-
-          // If there are multiple uses, they can try using the memory
-          // allocation already at the alternate memory.
-          definition_time = instruction_schedule.at(use.instruction);
-        }
-
-        // If the use has been a sequential call (e.g. a while loop), the other
-        // colocated intervals must alias with this allocation.
-        if (is_sequential_call) {
-          MemorySpaceAssignment::Allocation* aliased_allocation =
-              GetLiveAllocationAt(*allocation_value.allocation_sequence(),
-                                  use_time);
-          AddAliasedRequiredAssignmentsForSequentialCall(use,
-                                                         aliased_allocation);
-          // Remember the preferred offset to be used inside while loop body
-          // computations.
-          if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
-              use.instruction->opcode() == HloOpcode::kWhile) {
-            preferred_offset_for_computation[use.instruction->while_body()] =
-                aliased_allocation->chunk().offset;
-          }
-        }
-      }
-      if (!allocation_success) {
-        break;
-      }
-    }
-    if (allocation_success) {
-      for (AllocationValue& allocation_value : allocation_values) {
-        for (auto& allocation : *allocation_value.allocation_sequence()) {
-          AppendAllocationInfoDebugString(interval, *allocation,
-                                          &allocation_info_str);
-          allocations_->push_back(std::move(allocation));
-        }
-      }
-    }
-
-    pending_chunks_.clear();
-    pending_async_copies_.clear();
+    AllocateColocatedIntervals(colocated_intervals);
   }
 
   VLOG(3) << "Debug buffer info: ";
-  VLOG(3) << buffer_info_str;
+  VLOG(3) << buffer_info_str_;
   VLOG(3) << "Debug allocation info: ";
-  VLOG(3) << allocation_info_str;
-  DumpIfEnabled(buffer_info_str, allocation_info_str);
+  VLOG(3) << allocation_info_str_;
+  DumpDebugStringsIfEnabled();
 
   return result_;
 }
 
+void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
+    const std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>&
+        colocated_intervals) {
+  // Create AllocationValues for all the colocated intervals.
+  std::vector<AllocationValue> allocation_values;
+  for (const auto& colocated_interval : colocated_intervals) {
+    CreateAllocationValues(colocated_interval->buffer, &allocation_values);
+  }
+  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
+
+  // Data structure to contain the preferred offset for a given computation.
+  // We ensure that the same offset will be allocated outside the while loop
+  // as well as inside the while loop.
+  absl::flat_hash_map<const HloComputation*, int64>
+      preferred_offset_for_computation;
+
+  AppendBufferInfoDebugString(*colocated_intervals[0], &buffer_info_str_);
+
+  bool allocation_success = true;
+  for (auto& allocation_value : allocation_values) {
+    int64 definition_time =
+        instruction_schedule.at(allocation_value.defining_instruction());
+
+    absl::optional<int64> preferred_offset;
+    auto preferred_offset_it =
+        preferred_offset_for_computation.find(allocation_value.computation());
+    if (preferred_offset_it != preferred_offset_for_computation.end()) {
+      preferred_offset = preferred_offset_it->second;
+    }
+
+    // Iterate over the uses.
+    for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
+      const HloUse& use = allocation_value.uses().at(use_idx);
+      int64 use_time = instruction_schedule.at(use.instruction);
+      int64 latest_prefetch_time = use_time;
+      bool allow_no_copy_alternate_mem_allocation = true;
+      absl::optional<int64> earliest_prefetch_time = absl::nullopt;
+
+      // Sequential calls include kWhile, kCall, and kConditional opcodes.
+      bool is_sequential_call =
+          (GetInstructionCallContext(use.instruction->opcode()) ==
+           CallContext::kSequential);
+      if (is_sequential_call) {
+        for (const HloComputation* called_computation :
+             use.instruction->called_computations()) {
+          const HloLiveRange::TimeBound& computation_span =
+              hlo_live_range_.computation_span_times().at(called_computation);
+          latest_prefetch_time =
+              std::min(computation_span.start, latest_prefetch_time);
+        }
+        if (use.instruction->opcode() == HloOpcode::kWhile) {
+          // Given an example while loop and flattened schedule (logical times
+          // shown on the left):
+          //
+          // 0:  a = ...
+          // 1:  ...
+          //     cond {
+          // 2:   p = param(0)
+          // 3:   ...
+          //     }
+          //     body {
+          // 4:   p = param(0)
+          // 5:   ...
+          // 6:   ROOT ...
+          //     }
+          // 7:  w = while(a), body=body, cond=cond
+          //
+          // When processing "a" (time 0) and its while use (time 7), we update
+          // the interval to time 0-4. This is so that the remaining interval
+          // (5-6) can be allocated separately and this buffer doesn't waste
+          // alternate memory space within the while loop body.
+          HloComputation* while_body = use.instruction->while_body();
+          // We require while body ROOTs to be the last in the schedule.
+          CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
+                   instruction_schedule.at(use.instruction))
+              << "While body ROOTs need to be the last in the schedule! "
+                 "Please run RootInstructionSinker.";
+          // Replace the use time with the parameter time so that we can decide
+          // on alternate memory allocations within the while loop body when we
+          // look at uses within the while loop body.
+          use_time =
+              instruction_schedule.at(while_body->parameter_instruction(0));
+        } else if (use.instruction->opcode() == HloOpcode::kConditional) {
+          // Replace the use time with the earliest parameter of called
+          // computations.
+          for (const HloComputation* called_computation :
+               use.instruction->called_computations()) {
+            use_time = std::min(
+                use_time, instruction_schedule.at(
+                              called_computation->parameter_instruction(0)));
+          }
+        }
+      }
+
+      // Add a required assignment in default memory if the use not allowed in
+      // alternate memory.
+      if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
+        AddRequiredAssignment(allocation_value.value(), use.instruction,
+                              MemorySpace::kDefault, use_time);
+      } else if (use_idx > 0) {
+        // We allow buffers in alternate memory that are passed into
+        // conditionals to give up their alternate memory allocation inside the
+        // called computation. This means that if a conditional operator has an
+        // alternate memory allocation, subsequent uses cannot use the same
+        // alternate memory allocation in order not to clobber data. So we force
+        // default memory allocation for these subsequent uses.
+        const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
+        if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
+            previous_use.instruction != use.instruction) {
+          allow_no_copy_alternate_mem_allocation = false;
+          earliest_prefetch_time =
+              instruction_schedule.at(previous_use.instruction);
+          VLOG(3) << "Previous use (" << previous_use.ToString() << ") of use ("
+                  << use.ToString()
+                  << ") is a conditional, so this use will need to evict. "
+                  << "Earliest prefetch time = " << *earliest_prefetch_time;
+        }
+      }
+
+      // Bitcasts don't define buffers and don't directly consume buffers. Skip
+      // allocating buffers for bitcast uses. The uses that feed from bitcasts
+      // will be handled specially.
+      if (use.instruction->opcode() != HloOpcode::kBitcast) {
+        AllocationRequest request;
+        // Rarely, (e.g., when conditional true and false parameters are the
+        // same), definition time can be the time of the conditional and use
+        // time is the parameter use, which is less.
+        request.start_time = std::min(definition_time, use_time);
+        request.end_time = use_time;
+        request.latest_prefetch_time = latest_prefetch_time;
+        request.size = colocated_intervals[0]->size;
+        request.allow_no_copy_alternate_mem_allocation =
+            allow_no_copy_alternate_mem_allocation;
+        request.earliest_prefetch_time = earliest_prefetch_time;
+        request.preferred_offset = preferred_offset;
+        request.use = use;
+        request.allocation_value = &allocation_value;
+        if (!AllocateSegment(request)) {
+          // If the allocation finding failed (e.g., due to running out of
+          // asynchronous copies), then fall back to allocating the buffer
+          // entirely in the default memory.
+          UncommitPendingChunks();
+          allocation_success = false;
+          break;
+        }
+
+        // If there are multiple uses, they can try using the memory allocation
+        // already at the alternate memory.
+        definition_time = instruction_schedule.at(use.instruction);
+      }
+
+      // If the use has been a sequential call (e.g. a while loop), the other
+      // colocated intervals must alias with this allocation.
+      if (is_sequential_call) {
+        MemorySpaceAssignment::Allocation* aliased_allocation =
+            GetLiveAllocationAt(*allocation_value.allocation_sequence(),
+                                use_time);
+        AddAliasedRequiredAssignmentsForSequentialCall(use, aliased_allocation);
+        // Remember the preferred offset to be used inside while loop body
+        // computations.
+        if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
+            use.instruction->opcode() == HloOpcode::kWhile) {
+          preferred_offset_for_computation[use.instruction->while_body()] =
+              aliased_allocation->chunk().offset;
+        }
+      }
+    }
+    if (!allocation_success) {
+      break;
+    }
+  }
+  if (allocation_success) {
+    for (AllocationValue& allocation_value : allocation_values) {
+      for (auto& allocation : *allocation_value.allocation_sequence()) {
+        AppendAllocationInfoDebugString(*colocated_intervals[0], *allocation,
+                                        &allocation_info_str_);
+        allocations_->push_back(std::move(allocation));
+      }
+    }
+  }
+
+  pending_chunks_.clear();
+  pending_async_copies_.clear();
+}
+
 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
   return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
          (a.start_time <= b.start_time && a.end_time < b.end_time);
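The trailing context above ends at operator< for AsynchronousCopy, which this commit leaves unchanged: one copy compares less than another only when it both starts and ends no later. A standalone sketch follows (simplified hypothetical struct with int times; only the comparator body is taken from the source) showing how such an ordering makes nested copies (one starting later yet ending earlier, which would violate a first-in-first-out copy engine) compare as equivalent, so an ordered set can detect them:

#include <cassert>
#include <set>

// Minimal stand-in for the struct above; field names follow the comparator's
// usage. This is an illustrative sketch, not the full XLA type.
struct AsynchronousCopy {
  int start_time;
  int end_time;
};

// The strict weak ordering from the source: a < b only if a both starts and
// ends no later than b (and they are not identical in both endpoints).
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
  return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
         (a.start_time <= b.start_time && a.end_time < b.end_time);
}

int main() {
  // (0,4) and (2,6) overlap but respect FIFO order, so they are comparable.
  assert((AsynchronousCopy{0, 4} < AsynchronousCopy{2, 6}));

  // (2,3) starts after (0,5) yet ends before it. Neither compares less than
  // the other, so a std::set treats them as equivalent keys and the second
  // insertion fails -- one way such an ordering can be used to reject copies
  // that would have to overtake an in-flight copy.
  std::set<AsynchronousCopy> copies;
  copies.insert({0, 5});
  bool inserted = copies.insert({2, 3}).second;
  assert(!inserted);
  return 0;
}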
@@ -1395,7 +1393,7 @@ AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
   return required_assignment_at_time;
 }
 
-bool AlternateMemoryBestFitHeap::FindAllocation(
+bool AlternateMemoryBestFitHeap::AllocateSegment(
     const AllocationRequest& request) {
   auto allocation_sequence = request.allocation_value->allocation_sequence();
   // start_time == end_time is a special case where the value is consumed
@@ -890,7 +890,15 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   void CreateAllocationValues(const HloValue* value,
                               std::vector<AllocationValue>* allocation_values);
 
-  // Finds an allocation for the given interval.
+  // Finds allocations for colocated intervals. Colocated intervals consist of
+  // one or more BufferIntervals, each with a different HloValue. All of the
+  // intervals within colocated intervals have a must-alias relationship with
+  // each other.
+  void AllocateColocatedIntervals(
+      const std::vector<const BufferInterval*>& colocated_intervals);
+
+  // Finds an allocation for an allocation request for a segment (see the
+  // documentation for AllocationRequest above how a segment is defined).
   //
   // It performs three things in the following order:
   //  1- Allocate the allocation request entirely in the alternate memory, if
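To make the must-alias wording in the new comment concrete: every interval in one colocated group ends up at the same offset, the way a while loop's operand, body parameter, and result must share a single alternate-memory address. A toy sketch under that assumption (names are hypothetical, not the XLA API):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for a BufferInterval: one logical value and its size in bytes.
struct Interval {
  std::string value_name;
  int64_t size;
};

// Assigns one chunk to the whole group: every colocated interval receives the
// same offset, modeling the must-alias relationship described above.
std::map<std::string, int64_t> AllocateColocatedGroup(
    const std::vector<Interval>& group, int64_t* next_free_offset) {
  std::map<std::string, int64_t> offsets;
  const int64_t offset = *next_free_offset;
  *next_free_offset += group.front().size;  // one chunk serves the group
  for (const Interval& interval : group) {
    offsets[interval.value_name] = offset;  // shared, aliased offset
  }
  return offsets;
}

int main() {
  int64_t next_free_offset = 0;
  const auto offsets = AllocateColocatedGroup(
      {{"while_init", 128}, {"body_param", 128}, {"while_result", 128}},
      &next_free_offset);
  for (const auto& [name, offset] : offsets) {
    std::cout << name << " -> offset " << offset << "\n";  // all print 0
  }
  return 0;
}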
@@ -904,7 +912,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // false. This means we could not find a suitable allocation, so all previous
   // allocations for this buffer must be removed and allocated in the default
   // memory. Otherwise, this method returns true.
-  bool FindAllocation(const AllocationRequest& request);
+  bool AllocateSegment(const AllocationRequest& request);
 
   // Try allocating in alternate memory without any copies. Returns true if
   // successful.
@@ -1000,8 +1008,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
       const BufferInterval& interval,
       const MemorySpaceAssignment::Allocation& allocation,
       std::string* debug_str) const;
-  void DumpIfEnabled(absl::string_view buffer_info_str,
-                     absl::string_view allocation_info_str) const;
+  void DumpDebugStringsIfEnabled() const;
 
   // Returns the available heap size in the alternate memory.
   int64 available_heap_size() const {
@@ -1025,6 +1032,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
       required_assignments_;
   // Number of bytes reserved in alternate memory space.
   int64 reserved_in_bytes_ = 0;
+  // Debug strings.
+  std::string buffer_info_str_;
+  std::string allocation_info_str_;
 };
 
 }  // namespace xla