[XLA] Implement memory space allocation across sequential calls (e.g. while).

For now, the heuristics aren't very good, and they also need to allow moving a buffer between memory spaces inside the while body.

PiperOrigin-RevId: 281575698
Change-Id: I7c50a4ea4001021e0de44ff7643cc7a7cd44d7bd
parent 9401effbed
commit 162048befe
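In brief: buffers whose uses cross a sequential call (kWhile, kCall, or kConditional) can now be placed in the alternate memory space, provided any prefetch finishes before the called computation starts and all colocated values alias one allocation. A minimal sketch of how a caller opts in, using the Options field added below (the surrounding setup is illustrative, not part of this change):

    // Sketch: enable allocation across sequential calls via the new option.
    MemorySpaceAssignment::Options options;
    options.max_outstanding_async_copies = -1;        // -1 means unlimited
    options.allocate_across_sequential_calls = true;  // new field, defaults to false
    std::unique_ptr<PresetAssignments> preset_assignments =
        MemorySpaceAssignment::Run(module, options).ValueOrDie();

Internally, the pass clamps the latest time a prefetch may complete to the schedule start of each called computation (the std::min over computation_span_times() below), so the asynchronous CopyDone lands before the callee begins executing.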
--- a/tensorflow/compiler/xla/service/hlo_live_range.h
+++ b/tensorflow/compiler/xla/service/hlo_live_range.h
@@ -83,6 +83,12 @@ class HloLiveRange {
     return buffer_live_ranges_;
   }
 
+  // Returns the map from a computation and its time span in the schedule.
+  const absl::flat_hash_map<const HloComputation*, TimeBound>&
+  computation_span_times() const {
+    return computation_span_times_;
+  }
+
   // Returns the time stamp of the end of the program.
   LogicalTime schedule_end_time() const { return schedule_end_time_; }
 
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -237,24 +237,19 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
     }
 
     auto colocated_intervals = GetSortedColocatedIntervals(interval);
-    bool keep_in_default_memory = false;
-    for (const BufferInterval* colocated_interval : colocated_intervals) {
-      const HloValue* value = colocated_interval->buffer;
-      // If any of the colocated values are phi buffers, we keep them in the
-      // default memory for now.
-      if (value->is_phi()) {
-        keep_in_default_memory = true;
-        VLOG(4) << "Keeping value " << value->ToShortString()
-                << " because it contains a phi node.";
-        break;
-      }
+
+    if (colocated_intervals.size() > 1 &&
+        !options_.allocate_across_sequential_calls) {
+      VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
+              << " because it aliases with another interval and "
+              << " allocate_across_sequential_calls is false.";
+      continue;
     }
 
-    // At this point, none of the colocated buffers contain any phi buffers.
+    const HloComputation* defining_computation =
+        colocated_intervals[0]->buffer->defining_instruction()->parent();
+    MemorySpaceAssignment::Allocation* aliased_allocation = nullptr;
     for (const BufferInterval* colocated_interval : colocated_intervals) {
-      if (keep_in_default_memory) {
-        break;
-      }
       const HloValue* value = colocated_interval->buffer;
       const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
       MemorySpaceAssignment::AllocationSequence* allocation_sequence =
@@ -267,25 +262,66 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
         return instruction_schedule.at(use1.instruction) <
                instruction_schedule.at(use2.instruction);
       });
 
+      // If there was an aliased allocation for this buffer, propagate that for
+      // this HloValue.
+      if (aliased_allocation != nullptr) {
+        VLOG(3) << "Adding an aliased allocation: ("
+                << aliased_allocation->start_time() << ", "
+                << aliased_allocation->end_time()
+                << ") pos: " << aliased_allocation->defining_position()
+                << " mem space: "
+                << (aliased_allocation->memory_space() == MemorySpace::kDefault
+                        ? "default"
+                        : "alt");
+        allocation_sequence->push_back(
+            absl::make_unique<MemorySpaceAssignment::Allocation>(
+                value->defining_instruction(), value->defining_position(),
+                aliased_allocation->memory_space(), aliased_allocation->chunk(),
+                aliased_allocation->start_time(),
+                aliased_allocation->end_time()));
+      }
+
      // Iterate over the uses.
      for (HloUse use : uses) {
        int64 use_time = instruction_schedule.at(use.instruction);
        int64 last_use_time = instruction_schedule.at(uses.back().instruction);
+        int64 latest_prefetch_time = use_time;
+
+        if (use.instruction->parent() != defining_computation) {
+          VLOG(3) << "skip use " << use.ToString()
+                  << " because it's in a different computation.";
+          continue;
+        }
+
+        // Sequential calls include kWhile, kCall, and kConditional opcodes.
+        bool is_sequential_call =
+            (GetInstructionCallContext(use.instruction->opcode()) ==
+             CallContext::kSequential);
+        if (is_sequential_call) {
+          for (const HloComputation* called_computation :
+               use.instruction->called_computations()) {
+            const HloLiveRange::TimeBound& computation_span =
+                hlo_live_range_.computation_span_times().at(called_computation);
+            latest_prefetch_time =
+                std::min(computation_span.start, latest_prefetch_time);
+          }
+        }
+
        // Bitcasts don't define buffers and don't directly consume buffers.
        // Skip allocating buffers for bitcast uses. The uses that feed from
        // bitcasts will be handled specially.
        if (use.instruction->opcode() != HloOpcode::kBitcast) {
          if (!FindAllocation(definition_time, use_time, last_use_time,
-                              value->defining_position(), use, value,
-                              colocated_interval->size, allocation_sequence)) {
+                              latest_prefetch_time, value->defining_position(),
+                              use, value, colocated_interval->size,
+                              allocation_sequence)) {
            // If the allocation finding failed (e.g., due to running out of
            // asynchronous copies), then fall back to allocating the buffer
            // entirely in the default memory.
            pending_chunks_.clear();
            pending_async_copies_.clear();
            allocation_sequence->clear();
-            keep_in_default_memory = true;
            break;
          }
 
@@ -293,6 +329,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
          // allocation already at the alternate memory.
          definition_time = use_time;
        }
+
+        // If the use has been a sequential call (e.g. a while loop), the other
+        // colocated intervals must alias with this allocation.
+        if (is_sequential_call && !allocation_sequence->empty()) {
+          aliased_allocation = allocation_sequence->back().get();
+        }
      }
    }
 
@@ -390,8 +432,9 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks(
 
 bool AlternateMemoryBestFitHeap::FindAllocation(
     int64 start_time, int64 end_time, int64 last_use_time,
-    HloPosition defining_position, HloUse use, const HloValue* buffer,
-    int64 size, MemorySpaceAssignment::AllocationSequence* allocations) {
+    int64 latest_prefetch_time, HloPosition defining_position, HloUse use,
+    const HloValue* buffer, int64 size,
+    MemorySpaceAssignment::AllocationSequence* allocations) {
   HloInstruction* operand =
       use.instruction->mutable_operand(use.operand_number);
   // If the operand is a bitcast, we look at bitcast's operand until we find a
@@ -408,8 +451,10 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
   alternate_mem_interval.end = end_time;
 
   VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " ("
-          << start_time << ", " << end_time << ") last use = " << last_use_time
-          << " use = " << use.ToString() << ". Size = " << size
+          << start_time << ", " << end_time
+          << ") latest prefetch = " << latest_prefetch_time
+          << " last use = " << last_use_time << " use = " << use.ToString()
+          << ". Size = " << size
           << ", def pos = " << defining_position.ToString()
           << ", operand = " << operand->ToShortString()
          << (non_bitcast_operand != operand
@@ -445,19 +490,6 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
     }
   }
 
-  // TODO(berkin): This is curently overly restrictive and will fail using
-  // alternate memory for any buffer that might leak into a different
-  // computation (e.g., while body). Enable more usage of alternate memory
-  // across computations.
-  if (defining_position.instruction->parent() != use.instruction->parent() ||
-      (!use.instruction->called_computations().empty() &&
-       use.instruction->opcode() != HloOpcode::kFusion)) {
-    VLOG(3) << "Use is in a different computation or calls a computation.";
-    // Fail because we do not allow asynchronous copies while in the bodies of
-    // other computation.
-    return false;
-  }
-
   // First try keeping the allocation entirely in the alternate memory.
   if (!definition_requires_buffer_in_default_mem &&
       !use_requires_buffer_in_default_mem &&
@@ -491,7 +523,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
                 prev_allocation->end_time())) {
       AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
                    prev_allocation->start_time(), prev_allocation->end_time(),
-                   allocations);
+                   prev_allocation->end_time(), allocations);
 
     } else {
       VLOG(3) << "This violates the maximum async copies.";
@@ -504,7 +536,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
        if (!ViolatesMaximumOutstandingAsyncCopies(time, time)) {
          VLOG(3) << "Eviction successful.";
          AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
-                       time, time, allocations);
+                       time, time, time, allocations);
          eviction_scheduled = true;
          break;
        }
@@ -558,7 +590,8 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
  //    ^         ^
  //   Copy      Copy
  //   Start     Done
-  options_.prefetch_interval_picker->Begin(use, start_time, end_time);
+  options_.prefetch_interval_picker->Begin(use, start_time,
+                                           latest_prefetch_time);
  while (!options_.prefetch_interval_picker->Done()) {
    alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
    VLOG(4) << "Trying alternate memory allocation ("
@@ -583,7 +616,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
 
      AddAsyncCopy(*allocations->back().get(), MemorySpace::kAlternate,
                   chunk_candidate.chunk, alternate_mem_interval.start,
-                   end_time, allocations);
+                   end_time, latest_prefetch_time, allocations);
 
      allocations->back()->AddUse(use);
      return true;
@@ -598,16 +631,19 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
 void AlternateMemoryBestFitHeap::AddAsyncCopy(
     const MemorySpaceAssignment::Allocation& prev_allocation,
     MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time,
+    int64 copy_done_schedule_before_time,
     MemorySpaceAssignment::AllocationSequence* allocations) {
   VLOG(3) << "Copy to "
           << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
                   ? "default"
                   : "alternate")
-          << " memory between " << start_time << " and " << end_time;
+          << " memory between " << start_time << " and "
+          << copy_done_schedule_before_time << " keeping until " << end_time;
 
   allocations->push_back(
       absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
-          prev_allocation, memory_space, chunk, start_time, end_time));
+          prev_allocation, memory_space, chunk, start_time, end_time,
+          copy_done_schedule_before_time));
 
   // Register the additional async copy with the interval tree to keep track of
   // the limit at any given time.
@@ -828,9 +864,12 @@ MemorySpaceAssignment::Run(HloModule* module, const Options& options) {
       &memory_space_assignment.allocation_map_, options, *alias_analysis,
       *hlo_live_range);
 
+  HeapSimulator::Options heap_simulator_options;
+  heap_simulator_options.may_reuse_operand_buffers = false;
   TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
                                         module->schedule(),
-                                        *alias_analysis.get(), options.size_fn)
+                                        *alias_analysis.get(), options.size_fn,
+                                        heap_simulator_options)
                          .status());
 
   TF_RETURN_IF_ERROR(memory_space_assignment.Process());
@@ -1221,28 +1260,30 @@ Status MemorySpaceAssignment::FixSchedule() {
          instruction_index <
              flattened_instruction_sequence_.instructions().size();
          ++instruction_index) {
-      HloInstruction* instruction =
-          flattened_instruction_sequence_.instructions()[instruction_index];
-      if (instruction->parent() != computation) {
-        continue;
-      }
       auto insts_before_iter = schedule_before_.find(instruction_index);
       if (insts_before_iter != schedule_before_.end()) {
         for (HloInstruction* new_instruction : insts_before_iter->second) {
-          EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
-                                               &inserted_instructions);
+          if (new_instruction->parent() == computation) {
+            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
+                                                 &inserted_instructions);
+          }
         }
       }
+      HloInstruction* instruction =
+          flattened_instruction_sequence_.instructions()[instruction_index];
       // Insert only if not previously inserted.
-      if (!inserted_instructions.contains(instruction)) {
+      if (!inserted_instructions.contains(instruction) &&
+          instruction->parent() == computation) {
        EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                             &inserted_instructions);
      }
      auto insts_after_iter = schedule_after_.find(instruction_index);
      if (insts_after_iter != schedule_after_.end()) {
        for (HloInstruction* new_instruction : insts_after_iter->second) {
-          EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
-                                               &inserted_instructions);
+          if (new_instruction->parent() == computation) {
+            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
+                                                 &inserted_instructions);
+          }
        }
      }
    }
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -268,6 +268,10 @@ class MemorySpaceAssignment {
     // Specifies the upper bound for number of outstanding asynchronous copies,
     // -1 for unlimited.
     int64 max_outstanding_async_copies = -1;
+
+    // If true, tries allocating buffers across (e.g., before and inside a while
+    // loop body) sequential calls (kWhile, kCall, and kConditional).
+    bool allocate_across_sequential_calls = false;
   };
 
   // This class represents an allocation that might either be in the default or
@@ -363,13 +367,14 @@ class MemorySpaceAssignment {
   class CopyAllocation : public Allocation {
    public:
     CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
-                   Chunk chunk, int64 start_time, int64 end_time)
+                   Chunk chunk, int64 start_time, int64 end_time,
+                   int64 copy_done_schedule_before_time)
         : Allocation(/*instruction=*/nullptr,
                      /*defining_position=*/{nullptr, {}}, memory_space, chunk,
                      start_time, end_time),
           prev_allocation_(prev_allocation),
          copy_start_schedule_after_(start_time),
-          copy_done_schedule_before_(end_time) {}
+          copy_done_schedule_before_(copy_done_schedule_before_time) {}
 
     bool is_copy_allocation() const override { return true; }
 
@@ -525,8 +530,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // allocations can be in default or alternate memory spaces, or can be
   // prefetches or evictions. Returns true if successful.
   bool FindAllocation(int64 start_time, int64 end_time, int64 last_use_time,
-                      HloPosition defining_position, HloUse use,
-                      const HloValue* buffer, int64 size,
+                      int64 latest_prefetch_time, HloPosition defining_position,
+                      HloUse use, const HloValue* buffer, int64 size,
                       MemorySpaceAssignment::AllocationSequence* allocations);
 
   // Try allocating in alternate memory without any copies. Returns true if
@@ -560,7 +565,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // Adds an asynchronous copy to the allocations.
   void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                     MemorySpace memory_space, Chunk chunk, int64 start_time,
-                    int64 end_time,
+                    int64 end_time, int64 copy_done_schedule_before_time,
                     MemorySpaceAssignment::AllocationSequence* allocations);
 
   // These methods are used for delaying committing the chunk candidate until
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -35,7 +35,8 @@ int64 ShapeSize(const Shape& shape) {
   return ShapeUtil::ByteSizeOf(shape, kPointerSize);
 }
 
-class MemorySpaceAssignmentTest : public HloTestBase {
+class MemorySpaceAssignmentTest : public HloTestBase,
+                                  public ::testing::WithParamInterface<bool> {
  protected:
   // We use the following two memory space values to describe the default (slow
   // and large) and alternate (fast and small) memory spaces.
@@ -105,6 +106,7 @@ class MemorySpaceAssignmentTest : public HloTestBase {
     options.size_fn = size_fn;
     options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
     options.max_outstanding_async_copies = max_outstanding_async_copies;
+    options.allocate_across_sequential_calls = GetParam();
     std::unique_ptr<PresetAssignments> preset_assignments =
         MemorySpaceAssignment::Run(module, options).ValueOrDie();
     CheckPresetAssignments(preset_assignments.get());
@@ -190,7 +192,7 @@ class MemorySpaceAssignmentTest : public HloTestBase {
   }
 };
 
-TEST_F(MemorySpaceAssignmentTest, ParameterOnly) {
+TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
   // A module consisting of a single parameter. Inputs/outputs are currently
   // excluded from memory space assignment.
   HloComputation::Builder builder(TestName());
@@ -210,7 +212,7 @@ TEST_F(MemorySpaceAssignmentTest, ParameterOnly) {
   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
 }
 
-TEST_F(MemorySpaceAssignmentTest, Simple) {
+TEST_P(MemorySpaceAssignmentTest, Simple) {
   // A simple module with a few simple instructions. Expect this to be
   // transformed with CopyStart and CopyDone instructions inserted after inputs
   // and before outputs.
@@ -256,7 +258,7 @@ TEST_F(MemorySpaceAssignmentTest, Simple) {
             preset_assignments->chunks()[1].second.offset);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NegateChain) {
+TEST_P(MemorySpaceAssignmentTest, NegateChain) {
   // The negate chain is long enough for asynchronous copy to be inserted
   // between p1 and add.
   HloComputation::Builder builder(TestName());
@@ -319,7 +321,7 @@ TEST_F(MemorySpaceAssignmentTest, NegateChain) {
   EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
 }
 
-TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) {
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetch) {
   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
 
   AssignMemorySpace(module.get());
@@ -330,12 +332,9 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) {
           op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                         op::AsyncCopy(kDefaultMemorySpace,
                                       kAlternateMemorySpace, op::Tanh()))));
-
-  EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module),
-            2);
 }
 
-TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
 
   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0);
@@ -344,7 +343,7 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
             0);
 }
 
-TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
 
   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);
@@ -353,7 +352,16 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
             1);
 }
 
-TEST_F(MemorySpaceAssignmentTest, While) {
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) {
+  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
+
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2);
+
+  EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module),
+            2);
+}
+
+TEST_P(MemorySpaceAssignmentTest, While) {
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
@@ -429,14 +437,18 @@ TEST_F(MemorySpaceAssignmentTest, While) {
   AssignMemorySpace(module.get());
 
   // Ensure the tuple value and buffers used in the while instruction are
-  // exempted from using the alternate memory. However, body_data_mul is
-  // independent and can be safely be placed in the alternate memory.
-  EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
-  EXPECT_THAT(data, op::ShapeWithLayout(shape));
-  EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
-  EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
-  EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
-  EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
+  // exempted from using the alternate memory when allocating across sequential
+  // calls is disabled. However, body_data_mul is independent and can be safely
+  // be placed in the alternate memory.
+  const bool allocate_across_sequential_calls = GetParam();
+  if (!allocate_across_sequential_calls) {
+    EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
+    EXPECT_THAT(data, op::ShapeWithLayout(shape));
+    EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
+    EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
+    EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
+    EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
+  }
   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
       F32, {2, 3},
       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
@@ -444,7 +456,7 @@ TEST_F(MemorySpaceAssignmentTest, While) {
   EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
 }
 
-TEST_F(MemorySpaceAssignmentTest, Tuple) {
+TEST_P(MemorySpaceAssignmentTest, Tuple) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
@@ -499,7 +511,7 @@ TEST_F(MemorySpaceAssignmentTest, Tuple) {
                   op::GetTupleElement(op::GetTupleElement()))));
 }
 
-TEST_F(MemorySpaceAssignmentTest, Bitcast) {
+TEST_P(MemorySpaceAssignmentTest, Bitcast) {
   // Bitcasts can cause the position in the alternate memory to appear multiple
   // times in the preset assignments. This test ensure the preset assignments
   // refer to unique positions.
@@ -528,7 +540,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast) {
   EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
 }
 
-TEST_F(MemorySpaceAssignmentTest, Bitcast2) {
+TEST_P(MemorySpaceAssignmentTest, Bitcast2) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
@@ -564,7 +576,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast2) {
   EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
 }
 
-TEST_F(MemorySpaceAssignmentTest, Bitcast3) {
+TEST_P(MemorySpaceAssignmentTest, Bitcast3) {
   HloComputation::Builder builder(TestName());
   Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
   Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
@@ -627,7 +639,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast3) {
   EXPECT_EQ(bitcast4->shape().layout().memory_space(), kAlternateMemorySpace);
 }
 
-TEST_F(MemorySpaceAssignmentTest, BitcastTuple) {
+TEST_P(MemorySpaceAssignmentTest, BitcastTuple) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
@@ -678,7 +690,7 @@ TEST_F(MemorySpaceAssignmentTest, BitcastTuple) {
   AssignMemorySpace(module.get());
 }
 
-TEST_F(MemorySpaceAssignmentTest, LastUseOpt) {
+TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
  // Test that checks the last use optimization. It uses two buffers that should
  // be placed in alternate memory.
  //
@@ -735,7 +747,7 @@ TEST_F(MemorySpaceAssignmentTest, LastUseOpt) {
                         op::Add(op::Parameter(0), op::Parameter(0)))));
 }
 
-TEST_F(MemorySpaceAssignmentTest, CopyOrdering) {
+TEST_P(MemorySpaceAssignmentTest, CopyOrdering) {
  // Test to make sure the CopyStarts follow the same CopyDone order. The shapes
  // are picked in increasing order to exploit the fact that heap simulator
  // processes larger tensors first. This checks the ability of the compiler to
@@ -850,7 +862,7 @@ TEST_F(MemorySpaceAssignmentTest, CopyOrdering) {
   }
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
   // Test to ensure CopyStart/CopyDone is placed only in the entry computation.
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
@@ -934,7 +946,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
   AssignMemorySpace(module.get(), -1, 50);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@@ -1005,7 +1017,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
   AssignMemorySpace(module.get(), -1, 5);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@@ -1071,7 +1083,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
   AssignMemorySpace(module.get(), -1, 5);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@@ -1144,7 +1156,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
   AssignMemorySpace(module.get(), -1, 5);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
   // This test reproduces the failure in b/143288178. Given a graph like the
   // following:
   //
@@ -1242,7 +1254,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
       tuple_shape, cond_computation, body_computation, tuple));
   HloInstruction* while_data = builder.AddInstruction(
-      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
+      HloInstruction::CreateGetTupleElement(scalar_shape, while_op, 1));
   HloInstruction* root =
       builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
   HloComputation* entry_computation =
@@ -1265,7 +1277,143 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
   AssignMemorySpace(module.get(), -1, 20);
 }
 
-TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
+  auto module = CreateNewVerifiedModule();
+  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
+  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape, shape});
+
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+  HloInstruction* cond_iter = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+  HloInstruction* cond_limit = cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
+  HloInstruction* cond_lt = cond_builder.AddInstruction(
+      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
+                                    cond_limit, ComparisonDirection::kLt));
+  HloComputation* cond_computation =
+      module->AddEmbeddedComputation(cond_builder.Build());
+
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+  HloInstruction* body_iter = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
+  HloInstruction* body_data = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
+  HloInstruction* body_negate0 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_data));
+  HloInstruction* body_negate1 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate0));
+  HloInstruction* body_negate2 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate1));
+  HloInstruction* body_negate3 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate2));
+  HloInstruction* body_negate4 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate3));
+  HloInstruction* body_negate5 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate4));
+  HloInstruction* body_negate6 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate5));
+  HloInstruction* body_negate7 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate6));
+  HloInstruction* body_iter_increment = body_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
+  HloInstruction* body_iter_next =
+      body_builder.AddInstruction(HloInstruction::CreateBinary(
+          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
+  HloInstruction* body_out = body_builder.AddInstruction(
+      HloInstruction::CreateTuple({body_data, body_iter_next, body_negate7}));
+  HloComputation* body_computation =
+      module->AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* data = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param_data"));
+  HloInstruction* iter = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* negate7 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
+  HloInstruction* tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({data, iter, negate7}));
+  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, cond_computation, body_computation, tuple));
+  HloInstruction* while_data = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
+  HloInstruction* while_data2 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, while_op, 2));
+  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kAdd, while_data, while_data2));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(cond_computation,
+                        {cond_param, cond_iter, cond_limit, cond_lt});
+  schedule.set_sequence(
+      body_computation,
+      {body_param, body_iter, body_data, body_negate0, body_negate1,
+       body_negate2, body_negate3, body_negate4, body_negate5, body_negate6,
+       body_negate7, body_iter_increment, body_iter_next, body_out});
+  schedule.set_sequence(
+      entry_computation,
+      {iter, data, negate0, negate1, negate2, negate3, negate4, negate5,
+       negate6, negate7, tuple, while_op, while_data, while_data2, root});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  // Pick a large max prefetch interval to ensure all the while inputs are
+  // allocated in the alternate memory.
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
+                    /*max_prefetch_interval=*/25);
+
+  int64 memory_space_across_while = kDefaultMemorySpace;
+  bool allocate_across_sequential_calls = GetParam();
+  if (allocate_across_sequential_calls) {
+    memory_space_across_while = kAlternateMemorySpace;
+  }
+
+  // Index {0} of the while loop argument is not written inside the while loop,
+  // so it can be trivially placed in the alternate memory space.
+  *ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
+      LayoutUtil::MakeLayout(
+          /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+          kAlternateMemorySpace);
+  // Indexes {1} and {2} of the while loop argument are only placed in the
+  // alternate memory if we enable the allocate_across_sequential_calls option.
+  *ShapeUtil::GetMutableSubshape(&tuple_shape, {1})->mutable_layout() =
+      LayoutUtil::MakeLayout(
+          /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+          memory_space_across_while);
+  *ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
+      LayoutUtil::MakeLayout(
+          /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+          memory_space_across_while);
+
+  // Expect the layout for the while loop and its aliased buffers.
+  EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
+  EXPECT_THAT(while_op->operand(0), op::ShapeWithLayout(tuple_shape));
+  EXPECT_THAT(cond_param, op::ShapeWithLayout(tuple_shape));
+  EXPECT_THAT(body_param, op::ShapeWithLayout(tuple_shape));
+  EXPECT_THAT(body_out, op::ShapeWithLayout(tuple_shape));
+}
+
+TEST_P(MemorySpaceAssignmentTest, DanglingCopy) {
   // This situation was encountered in vss, where there is a mismatch in the
   // memory space in preset assignments and the output graph.
   HloComputation::Builder builder(TestName());
@@ -1311,7 +1459,7 @@ TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
   AssignMemorySpace(module.get());
 }
 
-TEST_F(MemorySpaceAssignmentTest, MultiOutputFusion) {
+TEST_P(MemorySpaceAssignmentTest, MultiOutputFusion) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1348,7 +1496,7 @@ TEST_F(MemorySpaceAssignmentTest, MultiOutputFusion) {
   AssignMemorySpace(module.get());
 }
 
-TEST_F(MemorySpaceAssignmentTest, TupleInput) {
+TEST_P(MemorySpaceAssignmentTest, TupleInput) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1388,7 +1536,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleInput) {
   AssignMemorySpace(module.get());
 }
 
-TEST_F(MemorySpaceAssignmentTest, TupleToTuple1) {
+TEST_P(MemorySpaceAssignmentTest, TupleToTuple1) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1467,7 +1615,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple1) {
                             op::GetTupleElement(op::Fusion(), 1)))));
 }
 
-TEST_F(MemorySpaceAssignmentTest, TupleToTuple2) {
+TEST_P(MemorySpaceAssignmentTest, TupleToTuple2) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1547,7 +1695,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple2) {
                                 op::GetTupleElement(op::Fusion(), 1), 1))))));
 }
 
-TEST_F(MemorySpaceAssignmentTest, TupleToTuple3) {
+TEST_P(MemorySpaceAssignmentTest, TupleToTuple3) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1594,7 +1742,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple3) {
   EXPECT_THAT(fusion1, op::Fusion(op::Fusion()));
 }
 
-TEST_F(MemorySpaceAssignmentTest, InputOutputAlias) {
+TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1649,7 +1797,7 @@ TEST_F(MemorySpaceAssignmentTest, InputOutputAlias) {
             kDefaultMemorySpace);
 }
 
-TEST_F(MemorySpaceAssignmentTest, CostAnalysis) {
+TEST_P(MemorySpaceAssignmentTest, CostAnalysis) {
   // This is mostly a smoke test since it's difficult and brittle to work out
   // the cost of the HLO instructions.
   HloComputation::Builder builder(TestName());
@@ -1701,7 +1849,7 @@ TEST_F(MemorySpaceAssignmentTest, CostAnalysis) {
   EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
 }
 
-TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
+TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
  // This test is carefully crafted to force only negates to be allocated to the
  // alternate memory. The graph consists of interleaving negate and tanh
  // operations:
@@ -1762,16 +1910,16 @@ TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
       F32, {4, 6},
       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
       kDefaultMemorySpace);
-  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
-      F32, {4, 6},
-      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
-      kAlternateMemorySpace);
-  // Expect only negates to be in alternate memory space.
-  EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
+  // Expect only negates to be in alternate memory space. Not all might fit but
+  // make sure at least one does.
+  std::vector<HloInstruction*> negate_instructions = {negate0, negate1, negate2,
+                                                      negate3, negate4};
+  int64 num_negates_in_alternate_mem = absl::c_count_if(
+      negate_instructions, [&](const HloInstruction* instruction) {
+        return instruction->shape().layout().memory_space() ==
+               kAlternateMemorySpace;
+      });
+  EXPECT_GE(num_negates_in_alternate_mem, 1);
   EXPECT_THAT(tanh0, op::ShapeWithLayout(shape_in_default_mem));
   EXPECT_THAT(tanh1, op::ShapeWithLayout(shape_in_default_mem));
   EXPECT_THAT(tanh2, op::ShapeWithLayout(shape_in_default_mem));
@@ -1779,5 +1927,9 @@ TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
   EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
 }
 
+INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
+                         MemorySpaceAssignmentTest,
+                         ::testing::Values(false, true));
+
 }  // namespace
 }  // namespace xla
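A note on the test changes above: switching TEST_F to TEST_P and adding INSTANTIATE_TEST_SUITE_P runs every test in the suite twice, once with allocate_across_sequential_calls = false and once with true, retrieved inside each test via GetParam(). With this instantiation, googletest reports the two runs as, for example:

    MemorySpaceAssignmentInstantiation/MemorySpaceAssignmentTest.While/0   // param = false
    MemorySpaceAssignmentInstantiation/MemorySpaceAssignmentTest.While/1   // param = true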