[XLA] NFC: Refactor to split a very large method.

PiperOrigin-RevId: 314738045
Change-Id: I07ac601281ff03f0209727fef446568d1df576b2
Berkin Ilbeyi 2020-06-04 08:57:40 -07:00 committed by TensorFlower Gardener
parent 3897e02fc9
commit 47f7695c00
2 changed files with 207 additions and 199 deletions
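In outline: the per-interval allocation work that previously lived inside Finish() moves into a new AllocateColocatedIntervals() helper, the per-use FindAllocation() entry point is renamed AllocateSegment(), and the debug strings become members (buffer_info_str_, allocation_info_str_) emitted through DumpDebugStringsIfEnabled(). A condensed sketch of the resulting structure, with bodies elided (see the full diff below for the real code):

HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
  for (auto& interval : sorted_buffer_intervals) {
    // ... skip intervals that need no allocation or are already reserved ...
    auto colocated_intervals = GetSortedColocatedIntervals(interval);
    AllocateColocatedIntervals(colocated_intervals);  // previously inlined here
  }
  DumpDebugStringsIfEnabled();  // reads buffer_info_str_ / allocation_info_str_
  return result_;
}

void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
    const std::vector<const BufferInterval*>& colocated_intervals) {
  // Creates AllocationValues for the colocated intervals, then calls
  // AllocateSegment() (formerly FindAllocation()) for each non-bitcast use;
  // on failure it uncommits pending chunks and falls back to default memory.
}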


@@ -787,14 +787,12 @@ void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
}
}
void AlternateMemoryBestFitHeap::DumpIfEnabled(
absl::string_view buffer_info_str,
absl::string_view allocation_info_str) const {
void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
if (!options_.dump_fn) {
return;
}
options_.dump_fn("bufferinfo", buffer_info_str);
options_.dump_fn("allocinfo", allocation_info_str);
options_.dump_fn("bufferinfo", buffer_info_str_);
options_.dump_fn("allocinfo", allocation_info_str_);
}
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
@@ -816,9 +814,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
}
}
std::string buffer_info_str;
std::string allocation_info_str;
for (auto& interval : sorted_buffer_intervals) {
if (!interval.need_allocation) {
continue;
@@ -842,12 +837,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
}
auto colocated_intervals = GetSortedColocatedIntervals(interval);
// Create AllocationValues for all the
// colocated intervals.
std::vector<AllocationValue> allocation_values;
for (const auto& colocated_interval : colocated_intervals) {
CreateAllocationValues(colocated_interval->buffer, &allocation_values);
}
if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
VLOG(3) << "Interval " << interval.buffer->ToShortString()
@@ -890,8 +879,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
continue;
}
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
// TODO(berkin): For now, place the phi values due to conditionals in
// default memory.
for (const BufferInterval* colocated_interval : colocated_intervals) {
@@ -911,192 +898,203 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
}
}
AppendBufferInfoDebugString(interval, &buffer_info_str);
// Data structure to contain the preferred offset for a given computation.
// We ensure that the same offset will be allocated outside the while loop
// as well as inside the while loop.
absl::flat_hash_map<const HloComputation*, int64>
preferred_offset_for_computation;
bool allocation_success = true;
for (auto& allocation_value : allocation_values) {
int64 definition_time =
instruction_schedule.at(allocation_value.defining_instruction());
absl::optional<int64> preferred_offset;
auto preferred_offset_it =
preferred_offset_for_computation.find(allocation_value.computation());
if (preferred_offset_it != preferred_offset_for_computation.end()) {
preferred_offset = preferred_offset_it->second;
}
// Iterate over the uses.
for (int use_idx = 0; use_idx < allocation_value.uses().size();
++use_idx) {
const HloUse& use = allocation_value.uses().at(use_idx);
int64 use_time = instruction_schedule.at(use.instruction);
int64 latest_prefetch_time = use_time;
bool allow_no_copy_alternate_mem_allocation = true;
absl::optional<int64> earliest_prefetch_time = absl::nullopt;
// Sequential calls include kWhile, kCall, and kConditional opcodes.
bool is_sequential_call =
(GetInstructionCallContext(use.instruction->opcode()) ==
CallContext::kSequential);
if (is_sequential_call) {
for (const HloComputation* called_computation :
use.instruction->called_computations()) {
const HloLiveRange::TimeBound& computation_span =
hlo_live_range_.computation_span_times().at(called_computation);
latest_prefetch_time =
std::min(computation_span.start, latest_prefetch_time);
}
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Given an example while loop and flattened schedule (logical times
// shown on the left):
//
// 0: a = ...
// 1: ...
// cond {
// 2: p = param(0)
// 3: ...
// }
// body {
// 4: p = param(0)
// 5: ...
// 6: ROOT ...
// }
// 7: w = while(a), body=body, cond=cond
//
// When processing "a" (time 0) and its while use (time 7), we
// update the interval to time 0-4. This is so that the remaining
// interval (5-6) can be allocated separately and this buffer
// doesn't waste alternate memory space within the while loop body.
HloComputation* while_body = use.instruction->while_body();
// We require while body ROOTs to be the last in the schedule.
CHECK_EQ(
instruction_schedule.at(while_body->root_instruction()) + 1,
instruction_schedule.at(use.instruction))
<< "While body ROOTs need to be the last in the schedule! "
"Please run RootInstructionSinker.";
// Replace the use time with the parameter time so that we can
// decide on alternate memory allocations within the while loop body
// when we look at uses within the while loop body.
use_time =
instruction_schedule.at(while_body->parameter_instruction(0));
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
// Replace the use time with the earliest parameter of called
// computations.
for (const HloComputation* called_computation :
use.instruction->called_computations()) {
use_time = std::min(
use_time, instruction_schedule.at(
called_computation->parameter_instruction(0)));
}
}
}
// Add a required assignment in default memory if the use is not allowed in
// alternate memory.
if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
AddRequiredAssignment(allocation_value.value(), use.instruction,
MemorySpace::kDefault, use_time);
} else if (use_idx > 0) {
// We allow buffers in alternate memory that are passed into
// conditionals to give up their alternate memory allocation inside
// the called computation. This means that if a conditional operator
// has an alternate memory allocation, subsequent uses cannot use the
// same alternate memory allocation in order not to clobber data. So
// we force default memory allocation for these subsequent uses.
const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
previous_use.instruction != use.instruction) {
allow_no_copy_alternate_mem_allocation = false;
earliest_prefetch_time =
instruction_schedule.at(previous_use.instruction);
VLOG(3) << "Previous use (" << previous_use.ToString()
<< ") of use (" << use.ToString()
<< ") is a conditional, so this use will need to evict. "
<< "Earliest prefetch time = " << *earliest_prefetch_time;
}
}
// Bitcasts don't define buffers and don't directly consume buffers.
// Skip allocating buffers for bitcast uses. The uses that feed from
// bitcasts will be handled specially.
if (use.instruction->opcode() != HloOpcode::kBitcast) {
AllocationRequest request;
// Rarely, (e.g., when conditional true and false parameters are the
// same), definition time can be the time of the conditional and use
// time is the parameter use, which is less.
request.start_time = std::min(definition_time, use_time);
request.end_time = use_time;
request.latest_prefetch_time = latest_prefetch_time;
request.size = interval.size;
request.allow_no_copy_alternate_mem_allocation =
allow_no_copy_alternate_mem_allocation;
request.earliest_prefetch_time = earliest_prefetch_time;
request.preferred_offset = preferred_offset;
request.use = use;
request.allocation_value = &allocation_value;
if (!FindAllocation(request)) {
// If the allocation finding failed (e.g., due to running out of
// asynchronous copies), then fall back to allocating the buffer
// entirely in the default memory.
UncommitPendingChunks();
allocation_success = false;
break;
}
// If there are multiple uses, they can try using the memory
// allocation already at the alternate memory.
definition_time = instruction_schedule.at(use.instruction);
}
// If the use has been a sequential call (e.g. a while loop), the other
// colocated intervals must alias with this allocation.
if (is_sequential_call) {
MemorySpaceAssignment::Allocation* aliased_allocation =
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
use_time);
AddAliasedRequiredAssignmentsForSequentialCall(use,
aliased_allocation);
// Remember the preferred offset to be used inside while loop body
// computations.
if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
use.instruction->opcode() == HloOpcode::kWhile) {
preferred_offset_for_computation[use.instruction->while_body()] =
aliased_allocation->chunk().offset;
}
}
}
if (!allocation_success) {
break;
}
}
if (allocation_success) {
for (AllocationValue& allocation_value : allocation_values) {
for (auto& allocation : *allocation_value.allocation_sequence()) {
AppendAllocationInfoDebugString(interval, *allocation,
&allocation_info_str);
allocations_->push_back(std::move(allocation));
}
}
}
pending_chunks_.clear();
pending_async_copies_.clear();
AllocateColocatedIntervals(colocated_intervals);
}
VLOG(3) << "Debug buffer info: ";
VLOG(3) << buffer_info_str;
VLOG(3) << buffer_info_str_;
VLOG(3) << "Debug allocation info: ";
VLOG(3) << allocation_info_str;
DumpIfEnabled(buffer_info_str, allocation_info_str);
VLOG(3) << allocation_info_str_;
DumpDebugStringsIfEnabled();
return result_;
}
void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
const std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>&
colocated_intervals) {
// Create AllocationValues for all the colocated intervals.
std::vector<AllocationValue> allocation_values;
for (const auto& colocated_interval : colocated_intervals) {
CreateAllocationValues(colocated_interval->buffer, &allocation_values);
}
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
// Data structure to contain the preferred offset for a given computation.
// We ensure that the same offset will be allocated outside the while loop
// as well as inside the while loop.
absl::flat_hash_map<const HloComputation*, int64>
preferred_offset_for_computation;
AppendBufferInfoDebugString(*colocated_intervals[0], &buffer_info_str_);
bool allocation_success = true;
for (auto& allocation_value : allocation_values) {
int64 definition_time =
instruction_schedule.at(allocation_value.defining_instruction());
absl::optional<int64> preferred_offset;
auto preferred_offset_it =
preferred_offset_for_computation.find(allocation_value.computation());
if (preferred_offset_it != preferred_offset_for_computation.end()) {
preferred_offset = preferred_offset_it->second;
}
// Iterate over the uses.
for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
const HloUse& use = allocation_value.uses().at(use_idx);
int64 use_time = instruction_schedule.at(use.instruction);
int64 latest_prefetch_time = use_time;
bool allow_no_copy_alternate_mem_allocation = true;
absl::optional<int64> earliest_prefetch_time = absl::nullopt;
// Sequential calls include kWhile, kCall, and kConditional opcodes.
bool is_sequential_call =
(GetInstructionCallContext(use.instruction->opcode()) ==
CallContext::kSequential);
if (is_sequential_call) {
for (const HloComputation* called_computation :
use.instruction->called_computations()) {
const HloLiveRange::TimeBound& computation_span =
hlo_live_range_.computation_span_times().at(called_computation);
latest_prefetch_time =
std::min(computation_span.start, latest_prefetch_time);
}
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Given an example while loop and flattened schedule (logical times
// shown on the left):
//
// 0: a = ...
// 1: ...
// cond {
// 2: p = param(0)
// 3: ...
// }
// body {
// 4: p = param(0)
// 5: ...
// 6: ROOT ...
// }
// 7: w = while(a), body=body, cond=cond
//
// When processing "a" (time 0) and its while use (time 7), we update
// the interval to time 0-4. This is so that the remaining interval
// (5-6) can be allocated separately and this buffer doesn't waste
// alternate memory space within the while loop body.
HloComputation* while_body = use.instruction->while_body();
// We require while body ROOTs to be the last in the schedule.
CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
instruction_schedule.at(use.instruction))
<< "While body ROOTs need to be the last in the schedule! "
"Please run RootInstructionSinker.";
// Replace the use time with the parameter time so that we can decide
// on alternate memory allocations within the while loop body when we
// look at uses within the while loop body.
use_time =
instruction_schedule.at(while_body->parameter_instruction(0));
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
// Replace the use time with the earliest parameter of called
// computations.
for (const HloComputation* called_computation :
use.instruction->called_computations()) {
use_time = std::min(
use_time, instruction_schedule.at(
called_computation->parameter_instruction(0)));
}
}
}
// Add a required assignment in default memory if the use is not allowed in
// alternate memory.
if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
AddRequiredAssignment(allocation_value.value(), use.instruction,
MemorySpace::kDefault, use_time);
} else if (use_idx > 0) {
// We allow buffers in alternate memory that are passed into
// conditionals to give up their alternate memory allocation inside the
// called computation. This means that if a conditional operator has an
// alternate memory allocation, subsequent uses cannot use the same
// alternate memory allocation in order not to clobber data. So we force
// default memory allocation for these subsequent uses.
const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
previous_use.instruction != use.instruction) {
allow_no_copy_alternate_mem_allocation = false;
earliest_prefetch_time =
instruction_schedule.at(previous_use.instruction);
VLOG(3) << "Previous use (" << previous_use.ToString() << ") of use ("
<< use.ToString()
<< ") is a conditional, so this use will need to evict. "
<< "Earliest prefetch time = " << *earliest_prefetch_time;
}
}
// Bitcasts don't define buffers and don't directly consume buffers. Skip
// allocating buffers for bitcast uses. The uses that feed from bitcasts
// will be handled specially.
if (use.instruction->opcode() != HloOpcode::kBitcast) {
AllocationRequest request;
// Rarely, (e.g., when conditional true and false parameters are the
// same), definition time can be the time of the conditional and use
// time is the parameter use, which is less.
request.start_time = std::min(definition_time, use_time);
request.end_time = use_time;
request.latest_prefetch_time = latest_prefetch_time;
request.size = colocated_intervals[0]->size;
request.allow_no_copy_alternate_mem_allocation =
allow_no_copy_alternate_mem_allocation;
request.earliest_prefetch_time = earliest_prefetch_time;
request.preferred_offset = preferred_offset;
request.use = use;
request.allocation_value = &allocation_value;
if (!AllocateSegment(request)) {
// If the allocation finding failed (e.g., due to running out of
// asynchronous copies), then fall back to allocating the buffer
// entirely in the default memory.
UncommitPendingChunks();
allocation_success = false;
break;
}
// If there are multiple uses, they can try using the memory allocation
// already at the alternate memory.
definition_time = instruction_schedule.at(use.instruction);
}
// If the use has been a sequential call (e.g. a while loop), the other
// colocated intervals must alias with this allocation.
if (is_sequential_call) {
MemorySpaceAssignment::Allocation* aliased_allocation =
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
use_time);
AddAliasedRequiredAssignmentsForSequentialCall(use, aliased_allocation);
// Remember the preferred offset to be used inside while loop body
// computations.
if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
use.instruction->opcode() == HloOpcode::kWhile) {
preferred_offset_for_computation[use.instruction->while_body()] =
aliased_allocation->chunk().offset;
}
}
}
if (!allocation_success) {
break;
}
}
if (allocation_success) {
for (AllocationValue& allocation_value : allocation_values) {
for (auto& allocation : *allocation_value.allocation_sequence()) {
AppendAllocationInfoDebugString(*colocated_intervals[0], *allocation,
&allocation_info_str_);
allocations_->push_back(std::move(allocation));
}
}
}
pending_chunks_.clear();
pending_async_copies_.clear();
}
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
(a.start_time <= b.start_time && a.end_time < b.end_time);
@@ -1395,7 +1393,7 @@ AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
return required_assignment_at_time;
}
bool AlternateMemoryBestFitHeap::FindAllocation(
bool AlternateMemoryBestFitHeap::AllocateSegment(
const AllocationRequest& request) {
auto allocation_sequence = request.allocation_value->allocation_sequence();
// start_time == end_time is a special case where the value is consumed


@@ -890,7 +890,15 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
void CreateAllocationValues(const HloValue* value,
std::vector<AllocationValue>* allocation_values);
// Finds an allocation for the given interval.
// Finds allocations for colocated intervals. Colocated intervals consist of
// one or more BufferIntervals, each with a different HloValue. All of the
// intervals within colocated intervals have a must-alias relationship with
// each other.
void AllocateColocatedIntervals(
const std::vector<const BufferInterval*>& colocated_intervals);
// Finds an allocation for an allocation request for a segment (see the
// documentation for AllocationRequest above for how a segment is defined).
//
// It performs three things in the following order:
// 1- Allocate the allocation request entirely in the alternate memory, if
@@ -904,7 +912,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// false. This means we could not find a suitable allocation, so all previous
// allocations for this buffer must be removed and allocated in the default
// memory. Otherwise, this method returns true.
bool FindAllocation(const AllocationRequest& request);
bool AllocateSegment(const AllocationRequest& request);
// Try allocating in alternate memory without any copies. Returns true if
// successful.
@@ -1000,8 +1008,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
const BufferInterval& interval,
const MemorySpaceAssignment::Allocation& allocation,
std::string* debug_str) const;
void DumpIfEnabled(absl::string_view buffer_info_str,
absl::string_view allocation_info_str) const;
void DumpDebugStringsIfEnabled() const;
// Returns the available heap size in the alternate memory.
int64 available_heap_size() const {
@@ -1025,6 +1032,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
required_assignments_;
// Number of bytes reserved in alternate memory space.
int64 reserved_in_bytes_ = 0;
// Debug strings.
std::string buffer_info_str_;
std::string allocation_info_str_;
};
} // namespace xla
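As context for the DumpDebugStringsIfEnabled() change: the dump output is still routed through options_.dump_fn, whose type is not part of this diff. A hypothetical wiring sketch, assuming the callback takes a category name and the dump contents (inferred from the dump_fn("bufferinfo", ...) and dump_fn("allocinfo", ...) call sites above):

// Hypothetical: the real Options struct and dump_fn signature are outside this
// diff; the shape below is only inferred from the two call sites.
options.dump_fn = [](absl::string_view category, absl::string_view contents) {
  LOG(INFO) << "memory_space_assignment " << category << ":\n" << contents;
};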