[XLA] Better alias handling in memory space assignment.

Instead of using ad-hoc alias rules (for kWhile and kConditional), we now use
the aliases reported by HloAliasAnalysis. This ensures that aliased values get
the same allocation. In practice, it enables us to share the buffer of a
DynamicUpdateSlice in a while loop in alternate memory. Sharing DUS buffers
that are not in while loops will additionally require changes to
HloDataflowAnalysis and copy insertion.

PiperOrigin-RevId: 315303035
Change-Id: I5f1057ed7df2b1f09138512be248cdc09533f54f
Author: Berkin Ilbeyi, 2020-06-08 10:30:16 -07:00 (committed by TensorFlower Gardener)
Commit: 60d63428b1 (parent: d2e0f75817)

3 changed files with 256 additions and 118 deletions
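For orientation, here is a minimal sketch of how the aliasing this change consumes can be queried. It is illustrative only: `SharesBuffer` and its arguments are hypothetical names, though `HloAliasAnalysis::Run`, `GetUniqueBufferAt`, and `HloBuffer::id` are the existing XLA API at this revision.

#include <memory>
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"

// Returns true if `dus` (e.g. a DynamicUpdateSlice in a while body) and the
// sub-buffer of `param` at `index` are reported as the same HloBuffer, i.e.
// memory space assignment may give both positions one allocation.
bool SharesBuffer(const xla::HloModule* module,
                  const xla::HloInstruction* dus,
                  const xla::HloInstruction* param,
                  const xla::ShapeIndex& index) {
  std::unique_ptr<xla::HloAliasAnalysis> analysis =
      xla::HloAliasAnalysis::Run(module).ConsumeValueOrDie();
  return analysis->GetUniqueBufferAt(dus).id() ==
         analysis->GetUniqueBufferAt(param, index).id();
}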

File: tensorflow/compiler/xla/service/memory_space_assignment.cc

@@ -432,8 +432,8 @@ std::string MemorySpaceAssignment::AllocationValue::ToString() const {
   absl::StrAppend(&out, "\n position:\n");
   absl::StrAppend(&out, " ", defining_position_.ToString(), "\n");
   absl::StrAppend(&out, " uses:\n");
-  for (const HloUse& use : uses_) {
-    absl::StrAppend(&out, " ", use.ToString(), "\n");
+  for (const Use& use : uses_) {
+    absl::StrAppend(&out, " ", use.hlo_use.ToString(), "\n");
   }
   return out;
 }
@@ -515,6 +515,53 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues(
   }
 }
 
+void AlternateMemoryBestFitHeap::FindAliases(
+    std::vector<AllocationValue>* allocation_values) const {
+  absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
+      values_by_defining_inst;
+  for (AllocationValue& value : *allocation_values) {
+    CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
+    values_by_defining_inst[value.defining_instruction()] = &value;
+  }
+  auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
+                                              AllocationValue::Use* use) {
+    auto aliased_value_it = values_by_defining_inst.find(instruction);
+    if (aliased_value_it != values_by_defining_inst.end()) {
+      VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() << " to "
+              << aliased_value_it->second->ToShortString();
+      use->aliases.push_back(aliased_value_it->second->defining_position());
+    }
+  };
+
+  for (AllocationValue& value : *allocation_values) {
+    for (AllocationValue::Use& use : value.uses()) {
+      // Find any aliases with the instruction itself (operand and output must
+      // alias).
+      maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);
+
+      // Find any aliases with the parameters of called computations.
+      for (const HloComputation* called_computation :
+           use.hlo_use.instruction->called_computations()) {
+        for (const HloInstruction* parameter_instruction :
+             called_computation->parameter_instructions()) {
+          maybe_add_alias_with_instruction(parameter_instruction, &use);
+        }
+      }
+
+      // Special case for kWhile: the root of the body computation must alias
+      // as well.
+      if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
+        HloPosition root_alias{
+            use.hlo_use.instruction->while_body()->root_instruction(),
+            use.hlo_use.operand_index};
+        VLOG(3) << "Adding while body root aliasing for use "
+                << use.hlo_use.ToString() << " to " << root_alias;
+        use.aliases.push_back(root_alias);
+      }
+    }
+  }
+}
+
 std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
 AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
     const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
@@ -675,18 +722,18 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
     // multiple called computations), determine if the parameter->first use
     // dependency is short.
     int64 conditional_time = instruction_schedule.at(use.instruction);
-    for (const HloUse& other_use : value.uses()) {
-      if (other_use.instruction != use.instruction) {
+    for (const AllocationValue::Use& other_use : value.uses()) {
+      if (other_use.hlo_use.instruction != use.instruction) {
         continue;
       }
       HloComputation* called_computation =
-          use.instruction->called_computations().at(other_use.operand_number -
-                                                    1);
+          use.instruction->called_computations().at(
+              other_use.hlo_use.operand_number - 1);
       const HloInstruction* parameter_instruction =
           called_computation->parameter_instruction(0);
       HloValue* parameter_value =
           &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
-              parameter_instruction, other_use.operand_index);
+              parameter_instruction, other_use.hlo_use.operand_index);
       int64 parameter_time = instruction_schedule.at(parameter_instruction);
       int64 min_use_time = conditional_time;
       for (const HloUse& parameter_use : parameter_value->uses()) {
@@ -947,6 +994,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
   for (const auto& colocated_interval : colocated_intervals) {
     CreateAllocationValues(colocated_interval->buffer, &allocation_values);
   }
+  FindAliases(&allocation_values);
   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
 
   // Data structure to contain the preferred offset for a given computation.
@@ -969,25 +1017,26 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
     // Iterate over the uses.
     for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
-      const HloUse& use = allocation_value.uses().at(use_idx);
-      int64 use_time = instruction_schedule.at(use.instruction);
+      const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
+      const HloUse hlo_use = use.hlo_use;
+      int64 use_time = instruction_schedule.at(hlo_use.instruction);
       int64 latest_prefetch_time = use_time;
       bool allow_no_copy_alternate_mem_allocation = true;
       absl::optional<int64> earliest_prefetch_time = absl::nullopt;
 
       // Sequential calls include kWhile, kCall, and kConditional opcodes.
       bool is_sequential_call =
-          (GetInstructionCallContext(use.instruction->opcode()) ==
+          (GetInstructionCallContext(hlo_use.instruction->opcode()) ==
           CallContext::kSequential);
       if (is_sequential_call) {
         for (const HloComputation* called_computation :
-             use.instruction->called_computations()) {
+             hlo_use.instruction->called_computations()) {
           const HloLiveRange::TimeBound& computation_span =
               hlo_live_range_.computation_span_times().at(called_computation);
           latest_prefetch_time =
               std::min(computation_span.start, latest_prefetch_time);
         }
-        if (use.instruction->opcode() == HloOpcode::kWhile) {
+        if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
          // Given an example while loop and flattened schedule (logical times
          // shown on the left):
          //
@@ -1008,10 +1057,10 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
          // the interval to time 0-4. This is so that the remaining interval
          // (5-6) can be allocated separately and this buffer doesn't waste
          // alternate memory space within the while loop body.
-          HloComputation* while_body = use.instruction->while_body();
+          HloComputation* while_body = hlo_use.instruction->while_body();
          // We require while body ROOTs to be the last in the schedule.
          CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
-                   instruction_schedule.at(use.instruction))
+                   instruction_schedule.at(hlo_use.instruction))
              << "While body ROOTs need to be the last in the schedule! "
                 "Please run RootInstructionSinker.";
          // Replace the use time with the parameter time so that we can decide
@@ -1019,11 +1068,11 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
          // look at uses within the while loop body.
          use_time =
              instruction_schedule.at(while_body->parameter_instruction(0));
-        } else if (use.instruction->opcode() == HloOpcode::kConditional) {
+        } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
          // Replace the use time with the earliest parameter of called
          // computations.
          for (const HloComputation* called_computation :
-               use.instruction->called_computations()) {
+               hlo_use.instruction->called_computations()) {
            use_time = std::min(
                use_time, instruction_schedule.at(
                              called_computation->parameter_instruction(0)));
@@ -1033,8 +1082,8 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
       // Add a required assignment in default memory if the use is not allowed
       // in alternate memory.
-      if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
-        AddRequiredAssignment(allocation_value.value(), use.instruction,
+      if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
+        AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
                               MemorySpace::kDefault, use_time);
       } else if (use_idx > 0) {
         // We allow buffers in alternate memory that are passed into
@@ -1043,14 +1092,16 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
        // alternate memory allocation, subsequent uses cannot use the same
        // alternate memory allocation in order not to clobber data. So we force
        // default memory allocation for these subsequent uses.
-        const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
-        if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
-            previous_use.instruction != use.instruction) {
+        const AllocationValue::Use& previous_use =
+            allocation_value.uses().at(use_idx - 1);
+        if (previous_use.hlo_use.instruction->opcode() ==
+                HloOpcode::kConditional &&
+            previous_use.hlo_use.instruction != hlo_use.instruction) {
          allow_no_copy_alternate_mem_allocation = false;
          earliest_prefetch_time =
-              instruction_schedule.at(previous_use.instruction);
-          VLOG(3) << "Previous use (" << previous_use.ToString() << ") of use ("
-                  << use.ToString()
+              instruction_schedule.at(previous_use.hlo_use.instruction);
+          VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
+                  << ") of use (" << hlo_use.ToString()
                  << ") is a conditional, so this use will need to evict. "
                  << "Earliest prefetch time = " << *earliest_prefetch_time;
        }
@@ -1059,7 +1110,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
      // Bitcasts don't define buffers and don't directly consume buffers. Skip
      // allocating buffers for bitcast uses. The uses that feed from bitcasts
      // will be handled specially.
-      if (use.instruction->opcode() != HloOpcode::kBitcast) {
+      if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) {
        AllocationRequest request;
        // Rarely, (e.g., when conditional true and false parameters are the
        // same), definition time can be the time of the conditional and use
@@ -1072,7 +1123,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
            allow_no_copy_alternate_mem_allocation;
        request.earliest_prefetch_time = earliest_prefetch_time;
        request.preferred_offset = preferred_offset;
-        request.use = use;
+        request.use = &use;
        request.allocation_value = &allocation_value;
        if (!AllocateSegment(request)) {
          // If the allocation finding failed (e.g., due to running out of
@@ -1085,23 +1136,25 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
        // If there are multiple uses, they can try using the memory allocation
        // already at the alternate memory.
-        definition_time = instruction_schedule.at(use.instruction);
+        definition_time = instruction_schedule.at(hlo_use.instruction);
      }
 
-      // If the use has been a sequential call (e.g. a while loop), the other
-      // colocated intervals must alias with this allocation.
-      if (is_sequential_call) {
-        MemorySpaceAssignment::Allocation* aliased_allocation =
-            GetLiveAllocationAt(*allocation_value.allocation_sequence(),
-                                use_time);
-        AddAliasedRequiredAssignmentsForSequentialCall(use, aliased_allocation);
-        // Remember the preferred offset to be used inside while loop body
-        // computations.
-        if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
-            use.instruction->opcode() == HloOpcode::kWhile) {
-          preferred_offset_for_computation[use.instruction->while_body()] =
-              aliased_allocation->chunk().offset;
-        }
-      }
+      // Propagate the allocation to any aliases this use might have had.
+      MemorySpaceAssignment::Allocation* aliased_allocation =
+          GetLiveAllocationAt(*allocation_value.allocation_sequence(),
+                              use_time);
+      for (const HloPosition& aliased_position : use.aliases) {
+        AddAliasedRequiredAssignment(aliased_position.instruction,
+                                     aliased_position.index,
+                                     aliased_allocation);
+      }
+
+      // Special case for while loops since the root offset must agree with
+      // other offsets: remember the preferred offset for the while loop body.
+      if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
+          aliased_allocation->memory_space() == MemorySpace::kAlternate) {
+        preferred_offset_for_computation[hlo_use.instruction->while_body()] =
+            aliased_allocation->chunk().offset;
+      }
    }
    if (!allocation_success) {
@@ -1212,34 +1265,45 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
   pending_required_assignments_.clear();
 }
 
-void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall(
-    const HloUse& use,
-    const MemorySpaceAssignment::Allocation* aliased_allocation) {
-  // Add aliased required assignments.
-  if (use.instruction->opcode() == HloOpcode::kWhile) {
-    HloComputation* while_body = use.instruction->while_body();
-    HloComputation* while_condition = use.instruction->while_condition();
-    AddAliasedRequiredAssignment(while_condition->parameter_instruction(0),
-                                 use.operand_index, aliased_allocation);
-    AddAliasedRequiredAssignment(while_body->parameter_instruction(0),
-                                 use.operand_index, aliased_allocation);
-    AddAliasedRequiredAssignment(while_body->root_instruction(),
-                                 use.operand_index, aliased_allocation);
-    AddAliasedRequiredAssignment(use.instruction, use.operand_index,
-                                 aliased_allocation);
-  } else if (use.instruction->opcode() == HloOpcode::kConditional) {
-    HloComputation* called_computation =
-        use.instruction->called_computations().at(use.operand_number - 1);
-    AddAliasedRequiredAssignment(called_computation->parameter_instruction(0),
-                                 use.operand_index, aliased_allocation);
-  } else {
-    CHECK(use.instruction->opcode() == HloOpcode::kCall);
-    HloComputation* called_computation =
-        use.instruction->called_computations().at(0);
-    AddAliasedRequiredAssignment(
-        called_computation->parameter_instruction(use.operand_number),
-        use.operand_index, aliased_allocation);
-  }
-}
+absl::optional<RequiredMemoryAssignment>
+AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
+                                                       int64 time) const {
+  auto required_assignment_it = required_assignments_.find(buffer);
+  absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
+  if (required_assignment_it != required_assignments_.end()) {
+    for (const RequiredMemoryAssignment& required_assignment :
+         required_assignment_it->second) {
+      if (required_assignment.time == time) {
+        // Sanity check that there is only one required at time.
+        CHECK(!required_assignment_at_time);
+        required_assignment_at_time = required_assignment;
+      }
+    }
+  }
+  return required_assignment_at_time;
+}
+
+absl::optional<RequiredMemoryAssignment>
+AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
+    const AllocationValue::Use& use) const {
+  absl::optional<RequiredMemoryAssignment> required_assignment;
+  for (const HloPosition& position : use.aliases) {
+    const HloValue* value =
+        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
+            position.instruction, position.index);
+    int64 time =
+        hlo_live_range_.instruction_schedule().at(position.instruction);
+    absl::optional<RequiredMemoryAssignment> required_assignment_for_alias =
+        RequiredMemoryAssignmentAt(value, time);
+    if (required_assignment == absl::nullopt) {
+      required_assignment = required_assignment_for_alias;
+    } else {
+      CHECK(required_assignment_for_alias == absl::nullopt ||
+            required_assignment->equals_ignoring_time(
+                *required_assignment_for_alias));
+    }
+  }
+  return required_assignment;
+}
 
 void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
@@ -1429,24 +1493,6 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks(
   CommitChunk(buffer_interval, chunk_candidate);
 }
 
-absl::optional<RequiredMemoryAssignment>
-AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
-                                                       int64 time) const {
-  auto required_assignment_it = required_assignments_.find(buffer);
-  absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
-  if (required_assignment_it != required_assignments_.end()) {
-    for (const RequiredMemoryAssignment& required_assignment :
-         required_assignment_it->second) {
-      if (required_assignment.time == time) {
-        // Sanity check that there is only one required at time.
-        CHECK(!required_assignment_at_time);
-        required_assignment_at_time = required_assignment;
-      }
-    }
-  }
-  return required_assignment_at_time;
-}
-
 bool AlternateMemoryBestFitHeap::AllocateSegment(
     const AllocationRequest& request) {
   auto allocation_sequence = request.allocation_value->allocation_sequence();
@@ -1457,7 +1503,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
    MemorySpaceAssignment::Allocation* allocation =
        GetLiveAllocationAt(*allocation_sequence, request.end_time);
    CHECK_NE(allocation, nullptr);
-    allocation->AddUse(request.use);
+    allocation->AddUse(request.use->hlo_use);
    return true;
  }
@@ -1467,8 +1513,9 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
          << request.allocation_value->ToShortString() << " ("
          << request.start_time << ", " << request.end_time
          << ") latest prefetch = " << request.latest_prefetch_time
-          << " last use = " << request.allocation_value->use_times().back()
-          << " use = " << request.use.ToString() << ". Size = " << request.size
+          << " last use = " << request.allocation_value->uses().back().time
+          << " use = " << request.use->hlo_use.ToString()
+          << ". Size = " << request.size
          << ", def pos = " << defining_position.ToString();
 
  CHECK_LE(request.start_time, request.end_time);
@@ -1483,8 +1530,21 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
  if (required_assignment_at_start) {
    required_memory_space_at_start = required_assignment_at_start->memory_space;
  }
+  // Find required assignment both for the use and its aliases. If they are
+  // both non-nullopt, then make sure they require the same assignment.
  auto required_assignment_at_end = RequiredMemoryAssignmentAt(
      request.allocation_value->value(), request.end_time);
+  auto aliased_required_assignment_at_end =
+      AliasedRequiredAssignmentForUse(*request.use);
+  if (required_assignment_at_end != aliased_required_assignment_at_end) {
+    if (required_assignment_at_end == absl::nullopt) {
+      required_assignment_at_end = aliased_required_assignment_at_end;
+    } else {
+      CHECK(aliased_required_assignment_at_end == absl::nullopt ||
+            aliased_required_assignment_at_end->equals_ignoring_time(
+                *required_assignment_at_end));
+    }
+  }
  absl::optional<MemorySpace> required_memory_space_at_end;
  if (required_assignment_at_end) {
    required_memory_space_at_end = required_assignment_at_end->memory_space;
@@ -1553,7 +1613,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
    VLOG(3)
        << "Not trying to prefetch because use requires buffer in default mem.";
    (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
-    (*prev_allocation_in_default_mem_it)->AddUse(request.use);
+    (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
    return true;
  }
@@ -1577,7 +1637,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
  // If a copy wasn't inserted, then add this use to the latest allocation in
  // default memory.
  (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
-  (*prev_allocation_in_default_mem_it)->AddUse(request.use);
+  (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
  return true;
 }
@@ -1746,7 +1806,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
          chunk_candidate->chunk, request.start_time, request.end_time));
    }
    request.allocation_value->allocation_sequence()->back()->AddUse(
-        request.use);
+        request.use->hlo_use);
    return true;
  }
  return false;
@@ -1833,7 +1893,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
  if (!eviction_scheduled) {
    // If the eviction couldn't be scheduled, then fail. This buffer will be
    // kept in the default memory.
-    VLOG(3) << "Bailing: Could not evict " << request.use.ToString()
+    VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
            << " because we hit the limit of maximum asynchronous copies "
            << "between "
            << hlo_live_range_.flattened_instruction_sequence()
@@ -1868,7 +1928,8 @@ bool AlternateMemoryBestFitHeap::Prefetch(
    earliest_prefetch_time =
        std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
  }
-  options_.prefetch_interval_picker->Begin(request.use, earliest_prefetch_time,
+  options_.prefetch_interval_picker->Begin(request.use->hlo_use,
+                                           earliest_prefetch_time,
                                           request.latest_prefetch_time);
  VLOG(3) << "Trying prefetch picker = "
          << options_.prefetch_interval_picker->ToDebugString();
@@ -1922,7 +1983,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
        request.allocation_value->allocation_sequence());
    request.allocation_value->allocation_sequence()->back()->AddUse(
-        request.use);
+        request.use->hlo_use);
    prefetch_failed_due_to_async_copy_ = false;
    return true;
  }
@@ -1938,11 +1999,11 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
  if (!preferred_offset) {
    // Find a chunk that's as long living as possible iterating in reverse over
    // the use times.
-    for (auto use_time = request.allocation_value->use_times().rbegin();
-         use_time != request.allocation_value->use_times().rend() &&
-         *use_time >= end_time;
-         ++use_time) {
-      alternate_mem_interval->end = *use_time;
+    for (auto use_it = request.allocation_value->uses().rbegin();
+         use_it != request.allocation_value->uses().rend() &&
+         use_it->time >= end_time;
+         ++use_it) {
+      alternate_mem_interval->end = use_it->time;
      ChunkCandidate chunk_candidate =
          FindChunkCandidate(*alternate_mem_interval);
      if (chunk_candidate.heap_size <= available_heap_size()) {
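To make the recorded aliases concrete: for a value used at operand 0 of a kWhile, the FindAliases pass added above ends up recording positions like the following. This is a hypothetical sketch: `while_inst` (the kWhile HloInstruction*) and `index` (the ShapeIndex of the used sub-buffer) are assumed, and in the real code each position is only recorded when the corresponding instruction defines one of the colocated AllocationValues.

// All of these positions must receive the same allocation.
std::vector<xla::HloPosition> aliases;
// The while instruction itself: its output aliases the operand.
aliases.push_back({while_inst, index});
// Parameters of the called computations (body and condition for kWhile).
aliases.push_back({while_inst->while_body()->parameter_instruction(0), index});
aliases.push_back(
    {while_inst->while_condition()->parameter_instruction(0), index});
// Special case: the while body root must alias as well, so the loop-carried
// buffer is shared across iterations.
aliases.push_back({while_inst->while_body()->root_instruction(), index});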

File: tensorflow/compiler/xla/service/memory_space_assignment.h

@@ -620,6 +620,18 @@ class MemorySpaceAssignment {
  //    add.5, operand 0
  class AllocationValue {
   public:
+    // This data structure wraps an HloUse and adds additional metadata that is
+    // useful for allocation.
+    struct Use {
+      // The wrapped HloUse object.
+      HloUse hlo_use;
+      // The logical time at which this use is scheduled.
+      int64 time;
+      // All the positions this use aliases with. The aliased positions must
+      // get the same allocation.
+      std::vector<HloPosition> aliases;
+    };
+
    AllocationValue(const HloValue* value, const HloPosition& position)
        : value_(value), defining_position_(position) {}
@@ -627,8 +639,8 @@ class MemorySpaceAssignment {
    const HloInstruction* defining_instruction() const {
      return defining_position().instruction;
    }
-    const std::vector<HloUse>& uses() const { return uses_; }
-    const std::vector<int64>& use_times() const { return use_times_; }
+    const std::vector<Use>& uses() const { return uses_; }
+    std::vector<Use>& uses() { return uses_; }
    const HloValue* value() const { return value_; }
    const HloComputation* computation() const {
      return defining_instruction()->parent();
@@ -636,8 +648,7 @@ class MemorySpaceAssignment {
    AllocationSequence* allocation_sequence() { return &allocation_sequence_; }
 
    void AddUse(const HloUse& use, int64 use_time) {
-      uses_.push_back(use);
-      use_times_.push_back(use_time);
+      uses_.push_back({use, use_time, {}});
    }
 
    std::string ToString() const;
@@ -646,8 +657,7 @@ class MemorySpaceAssignment {
   private:
    const HloValue* value_;
    HloPosition defining_position_;
-    std::vector<HloUse> uses_;
-    std::vector<int64> use_times_;
+    std::vector<Use> uses_;
    AllocationSequence allocation_sequence_;
  };
@@ -769,10 +779,18 @@ struct RequiredMemoryAssignment {
  int64 time;
  absl::optional<HeapSimulator::Chunk> chunk;
 
+  bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
+    return memory_space == other.memory_space && chunk == other.chunk;
+  }
+
  bool operator==(const RequiredMemoryAssignment& other) const {
    return memory_space == other.memory_space && time == other.time &&
           chunk == other.chunk;
  }
+
+  bool operator!=(const RequiredMemoryAssignment& other) const {
+    return !(*this == other);
+  }
 };
 
 // A struct representing an asynchronous copy with its logical start and end
@@ -880,7 +898,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
    bool allow_no_copy_alternate_mem_allocation;
    absl::optional<int64> earliest_prefetch_time;
    absl::optional<int64> preferred_offset;
-    HloUse use;
+    const MemorySpaceAssignment::AllocationValue::Use* use;
    MemorySpaceAssignment::AllocationValue* allocation_value;
  };
@@ -890,10 +908,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
  static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
      const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
 
-  // Returns the required assignment at a particular time, if available.
-  absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
-      const HloValue* buffer, int64 time) const;
-
  // Returns true if this buffer is allowed to be placed in the alternate
  // memory.
  bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
@@ -914,6 +928,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
  bool AllocateColocatedIntervals(
      const std::vector<const BufferInterval*>& colocated_intervals);
 
+  // Go through all the uses in the AllocationValues and find the aliasing
+  // positions.
+  void FindAliases(std::vector<AllocationValue>* allocation_values) const;
+
  // Finds an allocation for an allocation request for a segment (see the
  // documentation for AllocationRequest above how a segment is defined).
  //
@@ -950,12 +968,14 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
      const AllocationRequest& request, absl::optional<int64> preferred_offset,
      BufferInterval* alternate_mem_interval) const;
 
-  // At the end of an allocation with a sequential call (while, conditional, and
-  // call), this function adds the necessary aliased assignments within the
-  // called computations.
-  void AddAliasedRequiredAssignmentsForSequentialCall(
-      const HloUse& use,
-      const MemorySpaceAssignment::Allocation* aliased_allocation);
+  // Returns the required assignment at a particular time, if available.
+  absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
+      const HloValue* buffer, int64 time) const;
+
+  // Searches for aliases in the use for a required assignment, and returns it
+  // if found.
+  absl::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
+      const AllocationValue::Use& use) const;
 
  // Propagates aliased required assignment for a given position.
  void AddAliasedRequiredAssignment(
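As a usage note, here is a small sketch of the semantics the new equals_ignoring_time helper provides. It is illustrative only: `chunk` is an assumed absl::optional&lt;HeapSimulator::Chunk&gt; value, and MemorySpace abbreviates MemorySpaceAssignment::MemorySpace.

// Two required assignments that agree on memory space and chunk but differ in
// logical time. AliasedRequiredAssignmentForUse relies on this compatibility
// when merging requirements of aliased positions scheduled at different times.
RequiredMemoryAssignment a{MemorySpace::kAlternate, /*time=*/3, chunk};
RequiredMemoryAssignment b{MemorySpace::kAlternate, /*time=*/7, chunk};
CHECK(a.equals_ignoring_time(b));  // Space and chunk match; time is ignored.
CHECK(a != b);                     // The new operator!= still compares time.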

File: tensorflow/compiler/xla/service/memory_space_assignment_test.cc

@@ -1635,7 +1635,8 @@ TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
    %constant.5 = s32[1]{0:T(128)} constant({1})
    %prev.4 = s32[6]{0:T(128)} parameter(0)
    %rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
-    ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %constant.5, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
+    %neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5)
+    ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
  }
 
  %WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
@@ -1665,6 +1666,62 @@ TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
                       kDefaultMemorySpace);
 }
 
+TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) {
+  // Ensure that a dynamic update slice within a while loop is able to get an
+  // alternate memory allocation.
+  absl::string_view hlo_string = R"(
+  HloModule Module, is_scheduled=true
+
+  fused_computation {
+    param0 = f32[2,3] parameter(0)
+    constant.1 = f32[] constant(0)
+    broadcast = f32[2,1] broadcast(constant.1), dimensions={}
+    constant.3 = s32[] constant(0)
+    ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
+  }
+
+  %WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) {
+    %body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
+    %get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2
+    %get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0
+    %get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1
+    %fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation
+    %multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion)
+    ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1)
+  }
+
+  %WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] {
+    %cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
+    %get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2
+    %constant = f32[] constant(50)
+    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
+  }
+
+  ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] {
+    %param_iter = f32[] parameter(1)
+    %param_data = f32[2,3]{1,0} parameter(0)
+    %p2 = f32[2,3]{1,0} parameter(2)
+    %copy1 = f32[2,3]{1,0} copy(param_data)
+    %copy2 = f32[2,3]{1,0} copy(p2)
+    %tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter)
+    %while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
+    %get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0
+    ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+  const HloInstruction* while_op =
+      module->entry_computation()->GetInstructionWithName("while");
+  if (GetParam()) {
+    EXPECT_EQ(
+        ShapeUtil::GetSubshape(while_op->shape(), {1}).layout().memory_space(),
+        kAlternateMemorySpace);
+  }
+}
+
 TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
  // Having control_predecessors on an HLO was preventing us from DCEing an op
  // that doesn't have any users (tuple.1). The scheduler assumes the graph is