[XLA] NFC: Refactor to split a very large method.
PiperOrigin-RevId: 314738045 Change-Id: I07ac601281ff03f0209727fef446568d1df576b2
This commit is contained in:
parent
3897e02fc9
commit
47f7695c00
|
@ -787,14 +787,12 @@ void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AlternateMemoryBestFitHeap::DumpIfEnabled(
|
void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
|
||||||
absl::string_view buffer_info_str,
|
|
||||||
absl::string_view allocation_info_str) const {
|
|
||||||
if (!options_.dump_fn) {
|
if (!options_.dump_fn) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
options_.dump_fn("bufferinfo", buffer_info_str);
|
options_.dump_fn("bufferinfo", buffer_info_str_);
|
||||||
options_.dump_fn("allocinfo", allocation_info_str);
|
options_.dump_fn("allocinfo", allocation_info_str_);
|
||||||
}
|
}
|
||||||
|
|
||||||
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
|
@ -816,9 +814,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string buffer_info_str;
|
|
||||||
std::string allocation_info_str;
|
|
||||||
|
|
||||||
for (auto& interval : sorted_buffer_intervals) {
|
for (auto& interval : sorted_buffer_intervals) {
|
||||||
if (!interval.need_allocation) {
|
if (!interval.need_allocation) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -842,12 +837,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto colocated_intervals = GetSortedColocatedIntervals(interval);
|
auto colocated_intervals = GetSortedColocatedIntervals(interval);
|
||||||
// Create AllocationValues for all the
|
|
||||||
// colocated intervals.
|
|
||||||
std::vector<AllocationValue> allocation_values;
|
|
||||||
for (const auto& colocated_interval : colocated_intervals) {
|
|
||||||
CreateAllocationValues(colocated_interval->buffer, &allocation_values);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
|
if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
|
||||||
VLOG(3) << "Interval " << interval.buffer->ToShortString()
|
VLOG(3) << "Interval " << interval.buffer->ToShortString()
|
||||||
|
@ -890,8 +879,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
|
||||||
|
|
||||||
// TODO(berkin): For now, place the phi values due to conditionals in
|
// TODO(berkin): For now, place the phi values due to conditionals in
|
||||||
// default memory.
|
// default memory.
|
||||||
for (const BufferInterval* colocated_interval : colocated_intervals) {
|
for (const BufferInterval* colocated_interval : colocated_intervals) {
|
||||||
|
@ -911,13 +898,36 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AppendBufferInfoDebugString(interval, &buffer_info_str);
|
AllocateColocatedIntervals(colocated_intervals);
|
||||||
|
}
|
||||||
|
|
||||||
|
VLOG(3) << "Debug buffer info: ";
|
||||||
|
VLOG(3) << buffer_info_str_;
|
||||||
|
VLOG(3) << "Debug allocation info: ";
|
||||||
|
VLOG(3) << allocation_info_str_;
|
||||||
|
DumpDebugStringsIfEnabled();
|
||||||
|
|
||||||
|
return result_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||||
|
const std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>&
|
||||||
|
colocated_intervals) {
|
||||||
|
// Create AllocationValues for all the colocated intervals.
|
||||||
|
std::vector<AllocationValue> allocation_values;
|
||||||
|
for (const auto& colocated_interval : colocated_intervals) {
|
||||||
|
CreateAllocationValues(colocated_interval->buffer, &allocation_values);
|
||||||
|
}
|
||||||
|
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
||||||
|
|
||||||
// Data structure to contain the preferred offset for a given computation.
|
// Data structure to contain the preferred offset for a given computation.
|
||||||
// We ensure that the same offset will be allocated outside the while loop
|
// We ensure that the same offset will be allocated outside the while loop
|
||||||
// as well as inside the while loop.
|
// as well as inside the while loop.
|
||||||
absl::flat_hash_map<const HloComputation*, int64>
|
absl::flat_hash_map<const HloComputation*, int64>
|
||||||
preferred_offset_for_computation;
|
preferred_offset_for_computation;
|
||||||
|
|
||||||
|
AppendBufferInfoDebugString(*colocated_intervals[0], &buffer_info_str_);
|
||||||
|
|
||||||
bool allocation_success = true;
|
bool allocation_success = true;
|
||||||
for (auto& allocation_value : allocation_values) {
|
for (auto& allocation_value : allocation_values) {
|
||||||
int64 definition_time =
|
int64 definition_time =
|
||||||
|
@ -931,8 +941,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterate over the uses.
|
// Iterate over the uses.
|
||||||
for (int use_idx = 0; use_idx < allocation_value.uses().size();
|
for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
|
||||||
++use_idx) {
|
|
||||||
const HloUse& use = allocation_value.uses().at(use_idx);
|
const HloUse& use = allocation_value.uses().at(use_idx);
|
||||||
int64 use_time = instruction_schedule.at(use.instruction);
|
int64 use_time = instruction_schedule.at(use.instruction);
|
||||||
int64 latest_prefetch_time = use_time;
|
int64 latest_prefetch_time = use_time;
|
||||||
|
@ -968,20 +977,19 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
// }
|
// }
|
||||||
// 7: w = while(a), body=body, cond=cond
|
// 7: w = while(a), body=body, cond=cond
|
||||||
//
|
//
|
||||||
// When processing "a" (time 0) and its while use (time 7), we
|
// When processing "a" (time 0) and its while use (time 7), we update
|
||||||
// update the interval to time 0-4. This is so that the remaining
|
// the interval to time 0-4. This is so that the remaining interval
|
||||||
// interval (5-6) can be allocated separately and this buffer
|
// (5-6) can be allocated separately and this buffer doesn't waste
|
||||||
// doesn't waste alternate memory space within the while loop body.
|
// alternate memory space within the while loop body.
|
||||||
HloComputation* while_body = use.instruction->while_body();
|
HloComputation* while_body = use.instruction->while_body();
|
||||||
// We require while body ROOTs to be the last in the schedule.
|
// We require while body ROOTs to be the last in the schedule.
|
||||||
CHECK_EQ(
|
CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
|
||||||
instruction_schedule.at(while_body->root_instruction()) + 1,
|
|
||||||
instruction_schedule.at(use.instruction))
|
instruction_schedule.at(use.instruction))
|
||||||
<< "While body ROOTs need to be the last in the schedule! "
|
<< "While body ROOTs need to be the last in the schedule! "
|
||||||
"Please run RootInstructionSinker.";
|
"Please run RootInstructionSinker.";
|
||||||
// Replace the use time with the parameter time so that we can
|
// Replace the use time with the parameter time so that we can decide
|
||||||
// decide on alternate memory allocations within the while loop body
|
// on alternate memory allocations within the while loop body when we
|
||||||
// when we look at uses within the while loop body.
|
// look at uses within the while loop body.
|
||||||
use_time =
|
use_time =
|
||||||
instruction_schedule.at(while_body->parameter_instruction(0));
|
instruction_schedule.at(while_body->parameter_instruction(0));
|
||||||
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
|
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
|
||||||
|
@ -1003,27 +1011,27 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
MemorySpace::kDefault, use_time);
|
MemorySpace::kDefault, use_time);
|
||||||
} else if (use_idx > 0) {
|
} else if (use_idx > 0) {
|
||||||
// We allow buffers in alternate memory that are passed into
|
// We allow buffers in alternate memory that are passed into
|
||||||
// conditionals to give up their alternate memory allocation inside
|
// conditionals to give up their alternate memory allocation inside the
|
||||||
// the called computation. This means that if a conditional operator
|
// called computation. This means that if a conditional operator has an
|
||||||
// has an alternate memory allocation, subsequent uses cannot use the
|
// alternate memory allocation, subsequent uses cannot use the same
|
||||||
// same alternate memory allocation in order not to clobber data. So
|
// alternate memory allocation in order not to clobber data. So we force
|
||||||
// we force default memory allocation for these subsequent uses.
|
// default memory allocation for these subsequent uses.
|
||||||
const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
|
const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
|
||||||
if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
|
if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
|
||||||
previous_use.instruction != use.instruction) {
|
previous_use.instruction != use.instruction) {
|
||||||
allow_no_copy_alternate_mem_allocation = false;
|
allow_no_copy_alternate_mem_allocation = false;
|
||||||
earliest_prefetch_time =
|
earliest_prefetch_time =
|
||||||
instruction_schedule.at(previous_use.instruction);
|
instruction_schedule.at(previous_use.instruction);
|
||||||
VLOG(3) << "Previous use (" << previous_use.ToString()
|
VLOG(3) << "Previous use (" << previous_use.ToString() << ") of use ("
|
||||||
<< ") of use (" << use.ToString()
|
<< use.ToString()
|
||||||
<< ") is a conditional, so this use will need to evict. "
|
<< ") is a conditional, so this use will need to evict. "
|
||||||
<< "Earliest prefetch time = " << *earliest_prefetch_time;
|
<< "Earliest prefetch time = " << *earliest_prefetch_time;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bitcasts don't define buffers and don't directly consume buffers.
|
// Bitcasts don't define buffers and don't directly consume buffers. Skip
|
||||||
// Skip allocating buffers for bitcast uses. The uses that feed from
|
// allocating buffers for bitcast uses. The uses that feed from bitcasts
|
||||||
// bitcasts will be handled specially.
|
// will be handled specially.
|
||||||
if (use.instruction->opcode() != HloOpcode::kBitcast) {
|
if (use.instruction->opcode() != HloOpcode::kBitcast) {
|
||||||
AllocationRequest request;
|
AllocationRequest request;
|
||||||
// Rarely, (e.g., when conditional true and false parameters are the
|
// Rarely, (e.g., when conditional true and false parameters are the
|
||||||
|
@ -1032,14 +1040,14 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
request.start_time = std::min(definition_time, use_time);
|
request.start_time = std::min(definition_time, use_time);
|
||||||
request.end_time = use_time;
|
request.end_time = use_time;
|
||||||
request.latest_prefetch_time = latest_prefetch_time;
|
request.latest_prefetch_time = latest_prefetch_time;
|
||||||
request.size = interval.size;
|
request.size = colocated_intervals[0]->size;
|
||||||
request.allow_no_copy_alternate_mem_allocation =
|
request.allow_no_copy_alternate_mem_allocation =
|
||||||
allow_no_copy_alternate_mem_allocation;
|
allow_no_copy_alternate_mem_allocation;
|
||||||
request.earliest_prefetch_time = earliest_prefetch_time;
|
request.earliest_prefetch_time = earliest_prefetch_time;
|
||||||
request.preferred_offset = preferred_offset;
|
request.preferred_offset = preferred_offset;
|
||||||
request.use = use;
|
request.use = use;
|
||||||
request.allocation_value = &allocation_value;
|
request.allocation_value = &allocation_value;
|
||||||
if (!FindAllocation(request)) {
|
if (!AllocateSegment(request)) {
|
||||||
// If the allocation finding failed (e.g., due to running out of
|
// If the allocation finding failed (e.g., due to running out of
|
||||||
// asynchronous copies), then fall back to allocating the buffer
|
// asynchronous copies), then fall back to allocating the buffer
|
||||||
// entirely in the default memory.
|
// entirely in the default memory.
|
||||||
|
@ -1048,8 +1056,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are multiple uses, they can try using the memory
|
// If there are multiple uses, they can try using the memory allocation
|
||||||
// allocation already at the alternate memory.
|
// already at the alternate memory.
|
||||||
definition_time = instruction_schedule.at(use.instruction);
|
definition_time = instruction_schedule.at(use.instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1059,8 +1067,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
MemorySpaceAssignment::Allocation* aliased_allocation =
|
MemorySpaceAssignment::Allocation* aliased_allocation =
|
||||||
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
|
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
|
||||||
use_time);
|
use_time);
|
||||||
AddAliasedRequiredAssignmentsForSequentialCall(use,
|
AddAliasedRequiredAssignmentsForSequentialCall(use, aliased_allocation);
|
||||||
aliased_allocation);
|
|
||||||
// Remember the preferred offset to be used inside while loop body
|
// Remember the preferred offset to be used inside while loop body
|
||||||
// computations.
|
// computations.
|
||||||
if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
|
if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
|
||||||
|
@ -1077,8 +1084,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
if (allocation_success) {
|
if (allocation_success) {
|
||||||
for (AllocationValue& allocation_value : allocation_values) {
|
for (AllocationValue& allocation_value : allocation_values) {
|
||||||
for (auto& allocation : *allocation_value.allocation_sequence()) {
|
for (auto& allocation : *allocation_value.allocation_sequence()) {
|
||||||
AppendAllocationInfoDebugString(interval, *allocation,
|
AppendAllocationInfoDebugString(*colocated_intervals[0], *allocation,
|
||||||
&allocation_info_str);
|
&allocation_info_str_);
|
||||||
allocations_->push_back(std::move(allocation));
|
allocations_->push_back(std::move(allocation));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1088,15 +1095,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||||
pending_async_copies_.clear();
|
pending_async_copies_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(3) << "Debug buffer info: ";
|
|
||||||
VLOG(3) << buffer_info_str;
|
|
||||||
VLOG(3) << "Debug allocation info: ";
|
|
||||||
VLOG(3) << allocation_info_str;
|
|
||||||
DumpIfEnabled(buffer_info_str, allocation_info_str);
|
|
||||||
|
|
||||||
return result_;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
|
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
|
||||||
return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
|
return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
|
||||||
(a.start_time <= b.start_time && a.end_time < b.end_time);
|
(a.start_time <= b.start_time && a.end_time < b.end_time);
|
||||||
|
@ -1395,7 +1393,7 @@ AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
|
||||||
return required_assignment_at_time;
|
return required_assignment_at_time;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AlternateMemoryBestFitHeap::FindAllocation(
|
bool AlternateMemoryBestFitHeap::AllocateSegment(
|
||||||
const AllocationRequest& request) {
|
const AllocationRequest& request) {
|
||||||
auto allocation_sequence = request.allocation_value->allocation_sequence();
|
auto allocation_sequence = request.allocation_value->allocation_sequence();
|
||||||
// start_time == end_time is a special case where the value is consumed
|
// start_time == end_time is a special case where the value is consumed
|
||||||
|
|
|
@ -890,7 +890,15 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||||
void CreateAllocationValues(const HloValue* value,
|
void CreateAllocationValues(const HloValue* value,
|
||||||
std::vector<AllocationValue>* allocation_values);
|
std::vector<AllocationValue>* allocation_values);
|
||||||
|
|
||||||
// Finds an allocation for the given interval.
|
// Finds allocations for colocated intervals. Colocated intervals consist of
|
||||||
|
// one or more BufferIntervals, each with a different HloValue. All of the
|
||||||
|
// intervals within colocated intervals have a must-alias relationship with
|
||||||
|
// each other.
|
||||||
|
void AllocateColocatedIntervals(
|
||||||
|
const std::vector<const BufferInterval*>& colocated_intervals);
|
||||||
|
|
||||||
|
// Finds an allocation for an allocation request for a segment (see the
|
||||||
|
// documentation for AllocationRequest above for how a segment is defined).
|
||||||
//
|
//
|
||||||
// It performs three things in the following order:
|
// It performs three things in the following order:
|
||||||
// 1- Allocate the allocation request entirely in the alternate memory, if
|
// 1- Allocate the allocation request entirely in the alternate memory, if
|
||||||
|
@ -904,7 +912,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||||
// false. This means we could not find a suitable allocation, so all previous
|
// false. This means we could not find a suitable allocation, so all previous
|
||||||
// allocations for this buffer must be removed and allocated in the default
|
// allocations for this buffer must be removed and allocated in the default
|
||||||
// memory. Otherwise, this method returns true.
|
// memory. Otherwise, this method returns true.
|
||||||
bool FindAllocation(const AllocationRequest& request);
|
bool AllocateSegment(const AllocationRequest& request);
|
||||||
|
|
||||||
// Try allocating in alternate memory without any copies. Returns true if
|
// Try allocating in alternate memory without any copies. Returns true if
|
||||||
// successful.
|
// successful.
|
||||||
|
@ -1000,8 +1008,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||||
const BufferInterval& interval,
|
const BufferInterval& interval,
|
||||||
const MemorySpaceAssignment::Allocation& allocation,
|
const MemorySpaceAssignment::Allocation& allocation,
|
||||||
std::string* debug_str) const;
|
std::string* debug_str) const;
|
||||||
void DumpIfEnabled(absl::string_view buffer_info_str,
|
void DumpDebugStringsIfEnabled() const;
|
||||||
absl::string_view allocation_info_str) const;
|
|
||||||
|
|
||||||
// Returns the available heap size in the alternate memory.
|
// Returns the available heap size in the alternate memory.
|
||||||
int64 available_heap_size() const {
|
int64 available_heap_size() const {
|
||||||
|
@ -1025,6 +1032,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||||
required_assignments_;
|
required_assignments_;
|
||||||
// Number of bytes reserved in alternate memory space.
|
// Number of bytes reserved in alternate memory space.
|
||||||
int64 reserved_in_bytes_ = 0;
|
int64 reserved_in_bytes_ = 0;
|
||||||
|
// Debug strings.
|
||||||
|
std::string buffer_info_str_;
|
||||||
|
std::string allocation_info_str_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
Loading…
Reference in New Issue