[XLA] Better alias handling in memory space assignment.
Instead of using ad-hoc alias rules (for kWhile and kConditional), we use the aliases reported by HloAliasAnalysis. Using this, we can ensure that aliased values get the same allocation. In practice, this enables us to share the buffer of a DynamicUpdateSlice (DUS) in a while loop in alternate memory. For sharing DUS buffers that are not in while loops, we need to make changes to HloDataflowAnalysis and copy insertion.

PiperOrigin-RevId: 315303035
Change-Id: I5f1057ed7df2b1f09138512be248cdc09533f54f
This commit is contained in:
parent
d2e0f75817
commit
60d63428b1
@ -432,8 +432,8 @@ std::string MemorySpaceAssignment::AllocationValue::ToString() const {
|
|||||||
absl::StrAppend(&out, "\n position:\n");
|
absl::StrAppend(&out, "\n position:\n");
|
||||||
absl::StrAppend(&out, " ", defining_position_.ToString(), "\n");
|
absl::StrAppend(&out, " ", defining_position_.ToString(), "\n");
|
||||||
absl::StrAppend(&out, " uses:\n");
|
absl::StrAppend(&out, " uses:\n");
|
||||||
for (const HloUse& use : uses_) {
|
for (const Use& use : uses_) {
|
||||||
absl::StrAppend(&out, " ", use.ToString(), "\n");
|
absl::StrAppend(&out, " ", use.hlo_use.ToString(), "\n");
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -515,6 +515,53 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AlternateMemoryBestFitHeap::FindAliases(
|
||||||
|
std::vector<AllocationValue>* allocation_values) const {
|
||||||
|
absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
|
||||||
|
values_by_defining_inst;
|
||||||
|
for (AllocationValue& value : *allocation_values) {
|
||||||
|
CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
|
||||||
|
values_by_defining_inst[value.defining_instruction()] = &value;
|
||||||
|
}
|
||||||
|
auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
|
||||||
|
AllocationValue::Use* use) {
|
||||||
|
auto aliased_value_it = values_by_defining_inst.find(instruction);
|
||||||
|
if (aliased_value_it != values_by_defining_inst.end()) {
|
||||||
|
VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() << " to "
|
||||||
|
<< aliased_value_it->second->ToShortString();
|
||||||
|
use->aliases.push_back(aliased_value_it->second->defining_position());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (AllocationValue& value : *allocation_values) {
|
||||||
|
for (AllocationValue::Use& use : value.uses()) {
|
||||||
|
// Find any aliases with the instruction itself (operand and output must
|
||||||
|
// alias).
|
||||||
|
maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);
|
||||||
|
|
||||||
|
// Find any aliases with the parameters of called computations.
|
||||||
|
for (const HloComputation* called_computation :
|
||||||
|
use.hlo_use.instruction->called_computations()) {
|
||||||
|
for (const HloInstruction* parameter_instruction :
|
||||||
|
called_computation->parameter_instructions()) {
|
||||||
|
maybe_add_alias_with_instruction(parameter_instruction, &use);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special case for kWhile: the root of the body computation must alias as
|
||||||
|
// well.
|
||||||
|
if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
|
||||||
|
HloPosition root_alias{
|
||||||
|
use.hlo_use.instruction->while_body()->root_instruction(),
|
||||||
|
use.hlo_use.operand_index};
|
||||||
|
VLOG(3) << "Adding while body root aliasing for use "
|
||||||
|
<< use.hlo_use.ToString() << " to " << root_alias;
|
||||||
|
use.aliases.push_back(root_alias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
|
std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
|
||||||
AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
|
AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
|
||||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
|
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
|
||||||
@ -675,18 +722,18 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
|
|||||||
// multiple called computations), determine if the parameter->first use
|
// multiple called computations), determine if the parameter->first use
|
||||||
// dependency is short.
|
// dependency is short.
|
||||||
int64 conditional_time = instruction_schedule.at(use.instruction);
|
int64 conditional_time = instruction_schedule.at(use.instruction);
|
||||||
for (const HloUse& other_use : value.uses()) {
|
for (const AllocationValue::Use& other_use : value.uses()) {
|
||||||
if (other_use.instruction != use.instruction) {
|
if (other_use.hlo_use.instruction != use.instruction) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
HloComputation* called_computation =
|
HloComputation* called_computation =
|
||||||
use.instruction->called_computations().at(other_use.operand_number -
|
use.instruction->called_computations().at(
|
||||||
1);
|
other_use.hlo_use.operand_number - 1);
|
||||||
const HloInstruction* parameter_instruction =
|
const HloInstruction* parameter_instruction =
|
||||||
called_computation->parameter_instruction(0);
|
called_computation->parameter_instruction(0);
|
||||||
HloValue* parameter_value =
|
HloValue* parameter_value =
|
||||||
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
|
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
|
||||||
parameter_instruction, other_use.operand_index);
|
parameter_instruction, other_use.hlo_use.operand_index);
|
||||||
int64 parameter_time = instruction_schedule.at(parameter_instruction);
|
int64 parameter_time = instruction_schedule.at(parameter_instruction);
|
||||||
int64 min_use_time = conditional_time;
|
int64 min_use_time = conditional_time;
|
||||||
for (const HloUse& parameter_use : parameter_value->uses()) {
|
for (const HloUse& parameter_use : parameter_value->uses()) {
|
||||||
@ -947,6 +994,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
for (const auto& colocated_interval : colocated_intervals) {
|
for (const auto& colocated_interval : colocated_intervals) {
|
||||||
CreateAllocationValues(colocated_interval->buffer, &allocation_values);
|
CreateAllocationValues(colocated_interval->buffer, &allocation_values);
|
||||||
}
|
}
|
||||||
|
FindAliases(&allocation_values);
|
||||||
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
||||||
|
|
||||||
// Data structure to contain the preferred offset for a given computation.
|
// Data structure to contain the preferred offset for a given computation.
|
||||||
@ -969,25 +1017,26 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
|
|
||||||
// Iterate over the uses.
|
// Iterate over the uses.
|
||||||
for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
|
for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
|
||||||
const HloUse& use = allocation_value.uses().at(use_idx);
|
const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
|
||||||
int64 use_time = instruction_schedule.at(use.instruction);
|
const HloUse hlo_use = use.hlo_use;
|
||||||
|
int64 use_time = instruction_schedule.at(hlo_use.instruction);
|
||||||
int64 latest_prefetch_time = use_time;
|
int64 latest_prefetch_time = use_time;
|
||||||
bool allow_no_copy_alternate_mem_allocation = true;
|
bool allow_no_copy_alternate_mem_allocation = true;
|
||||||
absl::optional<int64> earliest_prefetch_time = absl::nullopt;
|
absl::optional<int64> earliest_prefetch_time = absl::nullopt;
|
||||||
|
|
||||||
// Sequential calls include kWhile, kCall, and kConditional opcodes.
|
// Sequential calls include kWhile, kCall, and kConditional opcodes.
|
||||||
bool is_sequential_call =
|
bool is_sequential_call =
|
||||||
(GetInstructionCallContext(use.instruction->opcode()) ==
|
(GetInstructionCallContext(hlo_use.instruction->opcode()) ==
|
||||||
CallContext::kSequential);
|
CallContext::kSequential);
|
||||||
if (is_sequential_call) {
|
if (is_sequential_call) {
|
||||||
for (const HloComputation* called_computation :
|
for (const HloComputation* called_computation :
|
||||||
use.instruction->called_computations()) {
|
hlo_use.instruction->called_computations()) {
|
||||||
const HloLiveRange::TimeBound& computation_span =
|
const HloLiveRange::TimeBound& computation_span =
|
||||||
hlo_live_range_.computation_span_times().at(called_computation);
|
hlo_live_range_.computation_span_times().at(called_computation);
|
||||||
latest_prefetch_time =
|
latest_prefetch_time =
|
||||||
std::min(computation_span.start, latest_prefetch_time);
|
std::min(computation_span.start, latest_prefetch_time);
|
||||||
}
|
}
|
||||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
|
||||||
// Given an example while loop and flattened schedule (logical times
|
// Given an example while loop and flattened schedule (logical times
|
||||||
// shown on the left):
|
// shown on the left):
|
||||||
//
|
//
|
||||||
@ -1008,10 +1057,10 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
// the interval to time 0-4. This is so that the remaining interval
|
// the interval to time 0-4. This is so that the remaining interval
|
||||||
// (5-6) can be allocated separately and this buffer doesn't waste
|
// (5-6) can be allocated separately and this buffer doesn't waste
|
||||||
// alternate memory space within the while loop body.
|
// alternate memory space within the while loop body.
|
||||||
HloComputation* while_body = use.instruction->while_body();
|
HloComputation* while_body = hlo_use.instruction->while_body();
|
||||||
// We require while body ROOTs to be the last in the schedule.
|
// We require while body ROOTs to be the last in the schedule.
|
||||||
CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
|
CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
|
||||||
instruction_schedule.at(use.instruction))
|
instruction_schedule.at(hlo_use.instruction))
|
||||||
<< "While body ROOTs need to be the last in the schedule! "
|
<< "While body ROOTs need to be the last in the schedule! "
|
||||||
"Please run RootInstructionSinker.";
|
"Please run RootInstructionSinker.";
|
||||||
// Replace the use time with the parameter time so that we can decide
|
// Replace the use time with the parameter time so that we can decide
|
||||||
@ -1019,11 +1068,11 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
// look at uses within the while loop body.
|
// look at uses within the while loop body.
|
||||||
use_time =
|
use_time =
|
||||||
instruction_schedule.at(while_body->parameter_instruction(0));
|
instruction_schedule.at(while_body->parameter_instruction(0));
|
||||||
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
|
} else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
|
||||||
// Replace the use time with the earliest parameter of called
|
// Replace the use time with the earliest parameter of called
|
||||||
// computations.
|
// computations.
|
||||||
for (const HloComputation* called_computation :
|
for (const HloComputation* called_computation :
|
||||||
use.instruction->called_computations()) {
|
hlo_use.instruction->called_computations()) {
|
||||||
use_time = std::min(
|
use_time = std::min(
|
||||||
use_time, instruction_schedule.at(
|
use_time, instruction_schedule.at(
|
||||||
called_computation->parameter_instruction(0)));
|
called_computation->parameter_instruction(0)));
|
||||||
@ -1033,8 +1082,8 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
|
|
||||||
// Add a required assignment in default memory if the use not allowed in
|
// Add a required assignment in default memory if the use not allowed in
|
||||||
// alternate memory.
|
// alternate memory.
|
||||||
if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
|
if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
|
||||||
AddRequiredAssignment(allocation_value.value(), use.instruction,
|
AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
|
||||||
MemorySpace::kDefault, use_time);
|
MemorySpace::kDefault, use_time);
|
||||||
} else if (use_idx > 0) {
|
} else if (use_idx > 0) {
|
||||||
// We allow buffers in alternate memory that are passed into
|
// We allow buffers in alternate memory that are passed into
|
||||||
@ -1043,14 +1092,16 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
// alternate memory allocation, subsequent uses cannot use the same
|
// alternate memory allocation, subsequent uses cannot use the same
|
||||||
// alternate memory allocation in order not to clobber data. So we force
|
// alternate memory allocation in order not to clobber data. So we force
|
||||||
// default memory allocation for these subsequent uses.
|
// default memory allocation for these subsequent uses.
|
||||||
const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
|
const AllocationValue::Use& previous_use =
|
||||||
if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
|
allocation_value.uses().at(use_idx - 1);
|
||||||
previous_use.instruction != use.instruction) {
|
if (previous_use.hlo_use.instruction->opcode() ==
|
||||||
|
HloOpcode::kConditional &&
|
||||||
|
previous_use.hlo_use.instruction != hlo_use.instruction) {
|
||||||
allow_no_copy_alternate_mem_allocation = false;
|
allow_no_copy_alternate_mem_allocation = false;
|
||||||
earliest_prefetch_time =
|
earliest_prefetch_time =
|
||||||
instruction_schedule.at(previous_use.instruction);
|
instruction_schedule.at(previous_use.hlo_use.instruction);
|
||||||
VLOG(3) << "Previous use (" << previous_use.ToString() << ") of use ("
|
VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
|
||||||
<< use.ToString()
|
<< ") of use (" << hlo_use.ToString()
|
||||||
<< ") is a conditional, so this use will need to evict. "
|
<< ") is a conditional, so this use will need to evict. "
|
||||||
<< "Earliest prefetch time = " << *earliest_prefetch_time;
|
<< "Earliest prefetch time = " << *earliest_prefetch_time;
|
||||||
}
|
}
|
||||||
@ -1059,7 +1110,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
// Bitcasts don't define buffers and don't directly consume buffers. Skip
|
// Bitcasts don't define buffers and don't directly consume buffers. Skip
|
||||||
// allocating buffers for bitcast uses. The uses that feed from bitcasts
|
// allocating buffers for bitcast uses. The uses that feed from bitcasts
|
||||||
// will be handled specially.
|
// will be handled specially.
|
||||||
if (use.instruction->opcode() != HloOpcode::kBitcast) {
|
if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) {
|
||||||
AllocationRequest request;
|
AllocationRequest request;
|
||||||
// Rarely, (e.g., when conditional true and false parameters are the
|
// Rarely, (e.g., when conditional true and false parameters are the
|
||||||
// same), definition time can be the time of the conditional and use
|
// same), definition time can be the time of the conditional and use
|
||||||
@ -1072,7 +1123,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
allow_no_copy_alternate_mem_allocation;
|
allow_no_copy_alternate_mem_allocation;
|
||||||
request.earliest_prefetch_time = earliest_prefetch_time;
|
request.earliest_prefetch_time = earliest_prefetch_time;
|
||||||
request.preferred_offset = preferred_offset;
|
request.preferred_offset = preferred_offset;
|
||||||
request.use = use;
|
request.use = &use;
|
||||||
request.allocation_value = &allocation_value;
|
request.allocation_value = &allocation_value;
|
||||||
if (!AllocateSegment(request)) {
|
if (!AllocateSegment(request)) {
|
||||||
// If the allocation finding failed (e.g., due to running out of
|
// If the allocation finding failed (e.g., due to running out of
|
||||||
@ -1085,23 +1136,25 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
|||||||
|
|
||||||
// If there are multiple uses, they can try using the memory allocation
|
// If there are multiple uses, they can try using the memory allocation
|
||||||
// already at the alternate memory.
|
// already at the alternate memory.
|
||||||
definition_time = instruction_schedule.at(use.instruction);
|
definition_time = instruction_schedule.at(hlo_use.instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the use has been a sequential call (e.g. a while loop), the other
|
// Propagate the allocation to any aliases this use might have had.
|
||||||
// colocated intervals must alias with this allocation.
|
|
||||||
if (is_sequential_call) {
|
|
||||||
MemorySpaceAssignment::Allocation* aliased_allocation =
|
MemorySpaceAssignment::Allocation* aliased_allocation =
|
||||||
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
|
GetLiveAllocationAt(*allocation_value.allocation_sequence(),
|
||||||
use_time);
|
use_time);
|
||||||
AddAliasedRequiredAssignmentsForSequentialCall(use, aliased_allocation);
|
for (const HloPosition& aliased_position : use.aliases) {
|
||||||
// Remember the preferred offset to be used inside while loop body
|
AddAliasedRequiredAssignment(aliased_position.instruction,
|
||||||
// computations.
|
aliased_position.index,
|
||||||
if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
|
aliased_allocation);
|
||||||
use.instruction->opcode() == HloOpcode::kWhile) {
|
|
||||||
preferred_offset_for_computation[use.instruction->while_body()] =
|
|
||||||
aliased_allocation->chunk().offset;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Special case for while loops since the root offset must agree with
|
||||||
|
// other offsets: remember the preferred offset for the while loop body.
|
||||||
|
if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
|
||||||
|
aliased_allocation->memory_space() == MemorySpace::kAlternate) {
|
||||||
|
preferred_offset_for_computation[hlo_use.instruction->while_body()] =
|
||||||
|
aliased_allocation->chunk().offset;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!allocation_success) {
|
if (!allocation_success) {
|
||||||
@ -1212,34 +1265,45 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
|
|||||||
pending_required_assignments_.clear();
|
pending_required_assignments_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall(
|
absl::optional<RequiredMemoryAssignment>
|
||||||
const HloUse& use,
|
AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
|
||||||
const MemorySpaceAssignment::Allocation* aliased_allocation) {
|
int64 time) const {
|
||||||
// Add aliased required assignments.
|
auto required_assignment_it = required_assignments_.find(buffer);
|
||||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
|
||||||
HloComputation* while_body = use.instruction->while_body();
|
if (required_assignment_it != required_assignments_.end()) {
|
||||||
HloComputation* while_condition = use.instruction->while_condition();
|
for (const RequiredMemoryAssignment& required_assignment :
|
||||||
AddAliasedRequiredAssignment(while_condition->parameter_instruction(0),
|
required_assignment_it->second) {
|
||||||
use.operand_index, aliased_allocation);
|
if (required_assignment.time == time) {
|
||||||
AddAliasedRequiredAssignment(while_body->parameter_instruction(0),
|
// Sanity check that there is only one required at time.
|
||||||
use.operand_index, aliased_allocation);
|
CHECK(!required_assignment_at_time);
|
||||||
AddAliasedRequiredAssignment(while_body->root_instruction(),
|
required_assignment_at_time = required_assignment;
|
||||||
use.operand_index, aliased_allocation);
|
|
||||||
AddAliasedRequiredAssignment(use.instruction, use.operand_index,
|
|
||||||
aliased_allocation);
|
|
||||||
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
|
|
||||||
HloComputation* called_computation =
|
|
||||||
use.instruction->called_computations().at(use.operand_number - 1);
|
|
||||||
AddAliasedRequiredAssignment(called_computation->parameter_instruction(0),
|
|
||||||
use.operand_index, aliased_allocation);
|
|
||||||
} else {
|
|
||||||
CHECK(use.instruction->opcode() == HloOpcode::kCall);
|
|
||||||
HloComputation* called_computation =
|
|
||||||
use.instruction->called_computations().at(0);
|
|
||||||
AddAliasedRequiredAssignment(
|
|
||||||
called_computation->parameter_instruction(use.operand_number),
|
|
||||||
use.operand_index, aliased_allocation);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return required_assignment_at_time;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::optional<RequiredMemoryAssignment>
|
||||||
|
AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
|
||||||
|
const AllocationValue::Use& use) const {
|
||||||
|
absl::optional<RequiredMemoryAssignment> required_assignment;
|
||||||
|
for (const HloPosition& position : use.aliases) {
|
||||||
|
const HloValue* value =
|
||||||
|
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
|
||||||
|
position.instruction, position.index);
|
||||||
|
int64 time =
|
||||||
|
hlo_live_range_.instruction_schedule().at(position.instruction);
|
||||||
|
absl::optional<RequiredMemoryAssignment> required_assignment_for_alias =
|
||||||
|
RequiredMemoryAssignmentAt(value, time);
|
||||||
|
if (required_assignment == absl::nullopt) {
|
||||||
|
required_assignment = required_assignment_for_alias;
|
||||||
|
} else {
|
||||||
|
CHECK(required_assignment_for_alias == absl::nullopt ||
|
||||||
|
required_assignment->equals_ignoring_time(
|
||||||
|
*required_assignment_for_alias));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return required_assignment;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
|
void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
|
||||||
@ -1429,24 +1493,6 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks(
|
|||||||
CommitChunk(buffer_interval, chunk_candidate);
|
CommitChunk(buffer_interval, chunk_candidate);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::optional<RequiredMemoryAssignment>
|
|
||||||
AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
|
|
||||||
int64 time) const {
|
|
||||||
auto required_assignment_it = required_assignments_.find(buffer);
|
|
||||||
absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
|
|
||||||
if (required_assignment_it != required_assignments_.end()) {
|
|
||||||
for (const RequiredMemoryAssignment& required_assignment :
|
|
||||||
required_assignment_it->second) {
|
|
||||||
if (required_assignment.time == time) {
|
|
||||||
// Sanity check that there is only one required at time.
|
|
||||||
CHECK(!required_assignment_at_time);
|
|
||||||
required_assignment_at_time = required_assignment;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return required_assignment_at_time;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AlternateMemoryBestFitHeap::AllocateSegment(
|
bool AlternateMemoryBestFitHeap::AllocateSegment(
|
||||||
const AllocationRequest& request) {
|
const AllocationRequest& request) {
|
||||||
auto allocation_sequence = request.allocation_value->allocation_sequence();
|
auto allocation_sequence = request.allocation_value->allocation_sequence();
|
||||||
@ -1457,7 +1503,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
|||||||
MemorySpaceAssignment::Allocation* allocation =
|
MemorySpaceAssignment::Allocation* allocation =
|
||||||
GetLiveAllocationAt(*allocation_sequence, request.end_time);
|
GetLiveAllocationAt(*allocation_sequence, request.end_time);
|
||||||
CHECK_NE(allocation, nullptr);
|
CHECK_NE(allocation, nullptr);
|
||||||
allocation->AddUse(request.use);
|
allocation->AddUse(request.use->hlo_use);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1467,8 +1513,9 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
|||||||
<< request.allocation_value->ToShortString() << " ("
|
<< request.allocation_value->ToShortString() << " ("
|
||||||
<< request.start_time << ", " << request.end_time
|
<< request.start_time << ", " << request.end_time
|
||||||
<< ") latest prefetch = " << request.latest_prefetch_time
|
<< ") latest prefetch = " << request.latest_prefetch_time
|
||||||
<< " last use = " << request.allocation_value->use_times().back()
|
<< " last use = " << request.allocation_value->uses().back().time
|
||||||
<< " use = " << request.use.ToString() << ". Size = " << request.size
|
<< " use = " << request.use->hlo_use.ToString()
|
||||||
|
<< ". Size = " << request.size
|
||||||
<< ", def pos = " << defining_position.ToString();
|
<< ", def pos = " << defining_position.ToString();
|
||||||
CHECK_LE(request.start_time, request.end_time);
|
CHECK_LE(request.start_time, request.end_time);
|
||||||
|
|
||||||
@ -1483,8 +1530,21 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
|||||||
if (required_assignment_at_start) {
|
if (required_assignment_at_start) {
|
||||||
required_memory_space_at_start = required_assignment_at_start->memory_space;
|
required_memory_space_at_start = required_assignment_at_start->memory_space;
|
||||||
}
|
}
|
||||||
|
// Find required assignment both for the use and its aliases. If they are both
|
||||||
|
// non-nullopt, then make sure they require the same assignment.
|
||||||
auto required_assignment_at_end = RequiredMemoryAssignmentAt(
|
auto required_assignment_at_end = RequiredMemoryAssignmentAt(
|
||||||
request.allocation_value->value(), request.end_time);
|
request.allocation_value->value(), request.end_time);
|
||||||
|
auto aliased_required_assignment_at_end =
|
||||||
|
AliasedRequiredAssignmentForUse(*request.use);
|
||||||
|
if (required_assignment_at_end != aliased_required_assignment_at_end) {
|
||||||
|
if (required_assignment_at_end == absl::nullopt) {
|
||||||
|
required_assignment_at_end = aliased_required_assignment_at_end;
|
||||||
|
} else {
|
||||||
|
CHECK(aliased_required_assignment_at_end == absl::nullopt ||
|
||||||
|
aliased_required_assignment_at_end->equals_ignoring_time(
|
||||||
|
*required_assignment_at_end));
|
||||||
|
}
|
||||||
|
}
|
||||||
absl::optional<MemorySpace> required_memory_space_at_end;
|
absl::optional<MemorySpace> required_memory_space_at_end;
|
||||||
if (required_assignment_at_end) {
|
if (required_assignment_at_end) {
|
||||||
required_memory_space_at_end = required_assignment_at_end->memory_space;
|
required_memory_space_at_end = required_assignment_at_end->memory_space;
|
||||||
@ -1553,7 +1613,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
|||||||
VLOG(3)
|
VLOG(3)
|
||||||
<< "Not trying to prefetch because use requires buffer in default mem.";
|
<< "Not trying to prefetch because use requires buffer in default mem.";
|
||||||
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
|
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
|
||||||
(*prev_allocation_in_default_mem_it)->AddUse(request.use);
|
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1577,7 +1637,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
|
|||||||
// If a copy wasn't inserted, then add this use to the latest allocation in
|
// If a copy wasn't inserted, then add this use to the latest allocation in
|
||||||
// default memory.
|
// default memory.
|
||||||
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
|
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
|
||||||
(*prev_allocation_in_default_mem_it)->AddUse(request.use);
|
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1746,7 +1806,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
|
|||||||
chunk_candidate->chunk, request.start_time, request.end_time));
|
chunk_candidate->chunk, request.start_time, request.end_time));
|
||||||
}
|
}
|
||||||
request.allocation_value->allocation_sequence()->back()->AddUse(
|
request.allocation_value->allocation_sequence()->back()->AddUse(
|
||||||
request.use);
|
request.use->hlo_use);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -1833,7 +1893,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
|
|||||||
if (!eviction_scheduled) {
|
if (!eviction_scheduled) {
|
||||||
// If the eviction couldn't be scheduled, then fail. This buffer will be
|
// If the eviction couldn't be scheduled, then fail. This buffer will be
|
||||||
// kept in the default memory.
|
// kept in the default memory.
|
||||||
VLOG(3) << "Bailing: Could not evict " << request.use.ToString()
|
VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
|
||||||
<< " because we hit the limit of maximum asynchronous copies "
|
<< " because we hit the limit of maximum asynchronous copies "
|
||||||
<< "between "
|
<< "between "
|
||||||
<< hlo_live_range_.flattened_instruction_sequence()
|
<< hlo_live_range_.flattened_instruction_sequence()
|
||||||
@ -1868,7 +1928,8 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
|||||||
earliest_prefetch_time =
|
earliest_prefetch_time =
|
||||||
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
|
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
|
||||||
}
|
}
|
||||||
options_.prefetch_interval_picker->Begin(request.use, earliest_prefetch_time,
|
options_.prefetch_interval_picker->Begin(request.use->hlo_use,
|
||||||
|
earliest_prefetch_time,
|
||||||
request.latest_prefetch_time);
|
request.latest_prefetch_time);
|
||||||
VLOG(3) << "Trying prefetch picker = "
|
VLOG(3) << "Trying prefetch picker = "
|
||||||
<< options_.prefetch_interval_picker->ToDebugString();
|
<< options_.prefetch_interval_picker->ToDebugString();
|
||||||
@ -1922,7 +1983,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
|||||||
request.allocation_value->allocation_sequence());
|
request.allocation_value->allocation_sequence());
|
||||||
|
|
||||||
request.allocation_value->allocation_sequence()->back()->AddUse(
|
request.allocation_value->allocation_sequence()->back()->AddUse(
|
||||||
request.use);
|
request.use->hlo_use);
|
||||||
prefetch_failed_due_to_async_copy_ = false;
|
prefetch_failed_due_to_async_copy_ = false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -1938,11 +1999,11 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
|
|||||||
if (!preferred_offset) {
|
if (!preferred_offset) {
|
||||||
// Find a chunk that's as long living as possible iterating in reverse over
|
// Find a chunk that's as long living as possible iterating in reverse over
|
||||||
// the use times.
|
// the use times.
|
||||||
for (auto use_time = request.allocation_value->use_times().rbegin();
|
for (auto use_it = request.allocation_value->uses().rbegin();
|
||||||
use_time != request.allocation_value->use_times().rend() &&
|
use_it != request.allocation_value->uses().rend() &&
|
||||||
*use_time >= end_time;
|
use_it->time >= end_time;
|
||||||
++use_time) {
|
++use_it) {
|
||||||
alternate_mem_interval->end = *use_time;
|
alternate_mem_interval->end = use_it->time;
|
||||||
ChunkCandidate chunk_candidate =
|
ChunkCandidate chunk_candidate =
|
||||||
FindChunkCandidate(*alternate_mem_interval);
|
FindChunkCandidate(*alternate_mem_interval);
|
||||||
if (chunk_candidate.heap_size <= available_heap_size()) {
|
if (chunk_candidate.heap_size <= available_heap_size()) {
|
||||||
|
@ -620,6 +620,18 @@ class MemorySpaceAssignment {
|
|||||||
// add.5, operand 0
|
// add.5, operand 0
|
||||||
class AllocationValue {
|
class AllocationValue {
|
||||||
public:
|
public:
|
||||||
|
// This data structure wraps an HloUse and adds additional metadata that are
|
||||||
|
// useful for allocation.
|
||||||
|
struct Use {
|
||||||
|
// The wrapped HloUse object.
|
||||||
|
HloUse hlo_use;
|
||||||
|
// The logical time this use is scheduled.
|
||||||
|
int64 time;
|
||||||
|
// All the positions where this use aliases with. The aliased positions
|
||||||
|
// must get the same allocation.
|
||||||
|
std::vector<HloPosition> aliases;
|
||||||
|
};
|
||||||
|
|
||||||
AllocationValue(const HloValue* value, const HloPosition& position)
|
AllocationValue(const HloValue* value, const HloPosition& position)
|
||||||
: value_(value), defining_position_(position) {}
|
: value_(value), defining_position_(position) {}
|
||||||
|
|
||||||
@ -627,8 +639,8 @@ class MemorySpaceAssignment {
|
|||||||
const HloInstruction* defining_instruction() const {
|
const HloInstruction* defining_instruction() const {
|
||||||
return defining_position().instruction;
|
return defining_position().instruction;
|
||||||
}
|
}
|
||||||
const std::vector<HloUse>& uses() const { return uses_; }
|
const std::vector<Use>& uses() const { return uses_; }
|
||||||
const std::vector<int64>& use_times() const { return use_times_; }
|
std::vector<Use>& uses() { return uses_; }
|
||||||
const HloValue* value() const { return value_; }
|
const HloValue* value() const { return value_; }
|
||||||
const HloComputation* computation() const {
|
const HloComputation* computation() const {
|
||||||
return defining_instruction()->parent();
|
return defining_instruction()->parent();
|
||||||
@ -636,8 +648,7 @@ class MemorySpaceAssignment {
|
|||||||
AllocationSequence* allocation_sequence() { return &allocation_sequence_; }
|
AllocationSequence* allocation_sequence() { return &allocation_sequence_; }
|
||||||
|
|
||||||
void AddUse(const HloUse& use, int64 use_time) {
|
void AddUse(const HloUse& use, int64 use_time) {
|
||||||
uses_.push_back(use);
|
uses_.push_back({use, use_time, {}});
|
||||||
use_times_.push_back(use_time);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ToString() const;
|
std::string ToString() const;
|
||||||
@ -646,8 +657,7 @@ class MemorySpaceAssignment {
|
|||||||
private:
|
private:
|
||||||
const HloValue* value_;
|
const HloValue* value_;
|
||||||
HloPosition defining_position_;
|
HloPosition defining_position_;
|
||||||
std::vector<HloUse> uses_;
|
std::vector<Use> uses_;
|
||||||
std::vector<int64> use_times_;
|
|
||||||
AllocationSequence allocation_sequence_;
|
AllocationSequence allocation_sequence_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -769,10 +779,18 @@ struct RequiredMemoryAssignment {
|
|||||||
int64 time;
|
int64 time;
|
||||||
absl::optional<HeapSimulator::Chunk> chunk;
|
absl::optional<HeapSimulator::Chunk> chunk;
|
||||||
|
|
||||||
|
bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
|
||||||
|
return memory_space == other.memory_space && chunk == other.chunk;
|
||||||
|
}
|
||||||
|
|
||||||
bool operator==(const RequiredMemoryAssignment& other) const {
|
bool operator==(const RequiredMemoryAssignment& other) const {
|
||||||
return memory_space == other.memory_space && time == other.time &&
|
return memory_space == other.memory_space && time == other.time &&
|
||||||
chunk == other.chunk;
|
chunk == other.chunk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool operator!=(const RequiredMemoryAssignment& other) const {
|
||||||
|
return !(*this == other);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// A struct representing an asynchronous copy with its logical start and end
|
// A struct representing an asynchronous copy with its logical start and end
|
||||||
@ -880,7 +898,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
bool allow_no_copy_alternate_mem_allocation;
|
bool allow_no_copy_alternate_mem_allocation;
|
||||||
absl::optional<int64> earliest_prefetch_time;
|
absl::optional<int64> earliest_prefetch_time;
|
||||||
absl::optional<int64> preferred_offset;
|
absl::optional<int64> preferred_offset;
|
||||||
HloUse use;
|
const MemorySpaceAssignment::AllocationValue::Use* use;
|
||||||
MemorySpaceAssignment::AllocationValue* allocation_value;
|
MemorySpaceAssignment::AllocationValue* allocation_value;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -890,10 +908,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
|
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
|
||||||
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
|
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
|
||||||
|
|
||||||
// Returns the required assignment at a particular time, if available.
|
|
||||||
absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
|
|
||||||
const HloValue* buffer, int64 time) const;
|
|
||||||
|
|
||||||
// Returns true if this buffer is allowed to be placed in the alternate
|
// Returns true if this buffer is allowed to be placed in the alternate
|
||||||
// memory.
|
// memory.
|
||||||
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
|
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
|
||||||
@ -914,6 +928,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
bool AllocateColocatedIntervals(
|
bool AllocateColocatedIntervals(
|
||||||
const std::vector<const BufferInterval*>& colocated_intervals);
|
const std::vector<const BufferInterval*>& colocated_intervals);
|
||||||
|
|
||||||
|
// Go through all the uses in the AllocationValues and find the aliasing
|
||||||
|
// positions.
|
||||||
|
void FindAliases(std::vector<AllocationValue>* allocation_values) const;
|
||||||
|
|
||||||
// Finds an allocation for an allocation request for a segment (see the
|
// Finds an allocation for an allocation request for a segment (see the
|
||||||
// documentation for AllocationRequest above how a segment is defined).
|
// documentation for AllocationRequest above how a segment is defined).
|
||||||
//
|
//
|
||||||
@ -950,12 +968,14 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
const AllocationRequest& request, absl::optional<int64> preferred_offset,
|
const AllocationRequest& request, absl::optional<int64> preferred_offset,
|
||||||
BufferInterval* alternate_mem_interval) const;
|
BufferInterval* alternate_mem_interval) const;
|
||||||
|
|
||||||
// At the end of an allocation with a sequential call (while, conditional, and
|
// Returns the required assignment at a particular time, if available.
|
||||||
// call), this function adds the necessary aliased assignments within the
|
absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
|
||||||
// called computations.
|
const HloValue* buffer, int64 time) const;
|
||||||
void AddAliasedRequiredAssignmentsForSequentialCall(
|
|
||||||
const HloUse& use,
|
// Searches for aliases in the use for a required assignment, and returns it
|
||||||
const MemorySpaceAssignment::Allocation* aliased_allocation);
|
// if found.
|
||||||
|
absl::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
|
||||||
|
const AllocationValue::Use& use) const;
|
||||||
|
|
||||||
// Propagates aliased required assignment for a given position.
|
// Propagates aliased required assignment for a given position.
|
||||||
void AddAliasedRequiredAssignment(
|
void AddAliasedRequiredAssignment(
|
||||||
|
@ -1635,7 +1635,8 @@ TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
|
|||||||
%constant.5 = s32[1]{0:T(128)} constant({1})
|
%constant.5 = s32[1]{0:T(128)} constant({1})
|
||||||
%prev.4 = s32[6]{0:T(128)} parameter(0)
|
%prev.4 = s32[6]{0:T(128)} parameter(0)
|
||||||
%rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
|
%rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
|
||||||
ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %constant.5, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
|
%neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5)
|
||||||
|
ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
|
||||||
}
|
}
|
||||||
|
|
||||||
%WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
|
%WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
|
||||||
@ -1665,6 +1666,62 @@ TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
|
|||||||
kDefaultMemorySpace);
|
kDefaultMemorySpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) {
|
||||||
|
// Ensure that a dynamic update slice within a while loop is able to get an
|
||||||
|
// alternate memory allocation.
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule Module, is_scheduled=true
|
||||||
|
|
||||||
|
fused_computation {
|
||||||
|
param0 = f32[2,3] parameter(0)
|
||||||
|
constant.1 = f32[] constant(0)
|
||||||
|
broadcast = f32[2,1] broadcast(constant.1), dimensions={}
|
||||||
|
constant.3 = s32[] constant(0)
|
||||||
|
ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) {
|
||||||
|
%body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2
|
||||||
|
%get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0
|
||||||
|
%get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1
|
||||||
|
%fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation
|
||||||
|
%multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion)
|
||||||
|
ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] {
|
||||||
|
%cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2
|
||||||
|
%constant = f32[] constant(50)
|
||||||
|
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] {
|
||||||
|
%param_iter = f32[] parameter(1)
|
||||||
|
%param_data = f32[2,3]{1,0} parameter(0)
|
||||||
|
%p2 = f32[2,3]{1,0} parameter(2)
|
||||||
|
%copy1 = f32[2,3]{1,0} copy(param_data)
|
||||||
|
%copy2 = f32[2,3]{1,0} copy(p2)
|
||||||
|
%tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter)
|
||||||
|
%while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
|
||||||
|
%get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0
|
||||||
|
ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
AssignMemorySpace(module.get());
|
||||||
|
const HloInstruction* while_op =
|
||||||
|
module->entry_computation()->GetInstructionWithName("while");
|
||||||
|
if (GetParam()) {
|
||||||
|
EXPECT_EQ(
|
||||||
|
ShapeUtil::GetSubshape(while_op->shape(), {1}).layout().memory_space(),
|
||||||
|
kAlternateMemorySpace);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
|
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
|
||||||
// Having control_predecessors on an HLO was preventing us from DCEing an op
|
// Having control_predecessors on an HLO was preventing us from DCEing an op
|
||||||
// that doesn't have any users (tuple.1). The scheduler assumes the graph is
|
// that doesn't have any users (tuple.1). The scheduler assumes the graph is
|
||||||
|
Loading…
Reference in New Issue
Block a user