[XLA] Better support for while loops in memory space assignment.

This CL makes changes to support the aliasing requirements of while loops. It
introduces the AllocationValue object which is similar to HloValue. While
HloValue may include multiple HloPositions across multiple Computations (that
alias with each other), there is one AllocationValue object per non-trivial
(excluding GTE, Tuple, and Bitcast) position. This data structure allows memory
space assignment to visit positions and uses in each computation separately, and
then propagate the aliased allocation decisions to other AllocationValue
objects.

Using this new data structure, memory space assignment now properly propagates
allocation decisions across aliased positions and uses by means of required
assignments.

This CL also introduces GetTransitiveColocationIntervals to HeapSimulator that
returns colocated intervals instead of HloValues. Using this API, memory space
assignment represents aliased logical times where AllocationValues must have the same
buffer assignment (e.g. while loop body parameter and while loop output).

PiperOrigin-RevId: 306260022
Change-Id: I169cc2a6ee78c5ad9b7bcc366e519f7c5750a4e2
This commit is contained in:
Berkin Ilbeyi 2020-04-13 10:34:24 -07:00 committed by TensorFlower Gardener
parent e646d8878a
commit 517e0274df
5 changed files with 1081 additions and 266 deletions

View File

@ -1392,23 +1392,22 @@ Status BufferAssigner::AssignPresetBuffers(
}
const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
const HloDataflowAnalysis& dataflow_analysis =
alias_analysis.dataflow_analysis();
for (auto& position_and_chunk : preset_assignments_->chunks()) {
const HloPosition& position = position_and_chunk.first;
const HloValue& value = dataflow_analysis.GetUniqueValueAt(
position.instruction, position.index);
VLOG(3) << "Preset allocation for value: " << value.ToShortString();
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
auto preset_allocations_iter = preset_allocations.find(value.color());
CHECK(preset_allocations_iter != preset_allocations.end())
<< "No preset value allocation for color " << value.color() << " for "
<< value.ToShortString() << " found.";
preset_allocations_iter->second->AddAssignment(value, chunk.offset,
chunk.size);
const HloPosition& defining_position = position_and_chunk.first;
const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
defining_position.instruction, defining_position.index);
for (const HloValue* value : buffer.values()) {
VLOG(3) << "Preset allocation for value: " << value->ToShortString();
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
auto preset_allocations_iter = preset_allocations.find(value->color());
CHECK(preset_allocations_iter != preset_allocations.end())
<< "No preset value allocation for color " << value->color()
<< " for " << value->ToShortString() << " found.";
preset_allocations_iter->second->AddAssignment(*value, chunk.offset,
chunk.size);
}
const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value);
assigned_buffers->insert(&buffer);
}

View File

@ -835,13 +835,6 @@ TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
// Set only one preset assignment for while data and its aliases.
auto preset_assignments = absl::make_unique<PresetAssignments>();
preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
preset_assignments->add_chunk({while_op, {1}}, {/*offset=*/100, /*size=*/40});
preset_assignments->add_chunk({cond_param, {1}},
{/*offset=*/100, /*size=*/40});
preset_assignments->add_chunk({body_param, {1}},
{/*offset=*/100, /*size=*/40});
preset_assignments->add_chunk({body_data_next, {}},
{/*offset=*/100, /*size=*/40});
preset_assignments->assignment_information_for_space(/*memory_space=*/1)
->size = 140;

File diff suppressed because it is too large Load Diff

View File

@ -363,7 +363,7 @@ class MemorySpaceAssignment {
class Allocation {
public:
Allocation(HloPosition defining_position, MemorySpace memory_space,
Chunk chunk, int64 start_time, int64 end_time)
absl::optional<Chunk> chunk, int64 start_time, int64 end_time)
: defining_position_(defining_position),
memory_space_(memory_space),
chunk_(chunk),
@ -393,7 +393,7 @@ class MemorySpaceAssignment {
const std::vector<HloUse>& uses() const { return uses_; }
MemorySpace memory_space() const { return memory_space_; }
Chunk chunk() const { return chunk_; }
Chunk chunk() const { return *chunk_; }
void set_start_time(int64 start_time) { start_time_ = start_time; }
int64 start_time() const { return start_time_; }
int64 end_time() const { return end_time_; }
@ -405,10 +405,14 @@ class MemorySpaceAssignment {
HloInstruction* tuple,
ShapeIndex shape_index);
// Recursively create kGetTupleElement instructions if the defining position
// shape is not an array. Returns the new instruction that has array shape.
HloInstruction* AddGetTupleElements();
HloPosition defining_position_;
std::vector<HloUse> uses_;
MemorySpace memory_space_;
Chunk chunk_;
absl::optional<Chunk> chunk_;
int64 start_time_;
int64 end_time_;
};
@ -417,8 +421,8 @@ class MemorySpaceAssignment {
class CopyAllocation : public Allocation {
public:
CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
Chunk chunk, int64 start_time, int64 end_time,
int64 copy_done_schedule_before_time)
absl::optional<Chunk> chunk, int64 start_time,
int64 end_time, int64 copy_done_schedule_before_time)
: Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
start_time, end_time),
prev_allocation_(prev_allocation),
@ -476,11 +480,105 @@ class MemorySpaceAssignment {
};
using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
struct ValueAndAllocationSequence {
const HloValue* value;
AllocationSequence sequence;
// AllocationValue is used to break up HloValues for each non-trivial position
// (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An
// HloValue may include positions and uses that alias with each other across
// multiple computations. We use this class to break these HloValues such that
// every AllocationValue has one defining position (that may alias with other
// AllocationValues). The uses field of the AllocationValue contains only the
// direct uses of the AllocationValue's defining position.
//
// For example, consider the following HLO snippet:
//
// Body {
// body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
// get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
// index=0
// ...
// ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
// }
//
// Cond {
// cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
// ...
// }
//
// add.4 = f32[4,3]{1,0} add(...)
// tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
// while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
// get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
// add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
//
// This contains an HloValue that looks like the following:
// positions:
// add.4
// body_param {0}
// get-tuple-element.3
// tuple {0}
// cond_param {0}
// tuple.1 {0}
// while {0}
// get-tuple-element.5
// uses:
// add.1, operand 0
// tuple, operand 0
// while, operand 0 {0}
// add.5, operand 0
//
// We break this HloValue up into the following AllocationValues for each
// non-trivial position:
// AllocationValue1: computation = Entry
// position:
// add.4
// uses:
// while, operand 0 {0}
// AllocationValue2: computation = Cond
// position:
// cond_param {0}
// uses:
// AllocationValue3: computation = Body
// position:
// body_param {0}
// uses:
// add.1, operand 0
// tuple, operand 0
// AllocationValue4: computation = Entry
// position:
// while {0}
// uses:
// add.5, operand 0
class AllocationValue {
 public:
  // Associates this AllocationValue with the HloValue it was split out of and
  // with the single position that defines it.
  AllocationValue(const HloValue* value, const HloPosition& position)
      : value_(value), defining_position_(position) {}

  // The HloValue this AllocationValue was created for.
  const HloValue* value() const { return value_; }

  // The defining position, plus convenience accessors for the instruction at
  // that position and the computation containing that instruction.
  const HloPosition& defining_position() const { return defining_position_; }
  const HloInstruction* defining_instruction() const {
    return defining_position_.instruction;
  }
  const HloComputation* computation() const {
    return defining_instruction()->parent();
  }

  // uses() and use_times() are parallel vectors: AddUse appends exactly one
  // entry to each, so element i of use_times() belongs to element i of uses().
  const std::vector<HloUse>& uses() const { return uses_; }
  const std::vector<int64>& use_times() const { return use_times_; }
  void AddUse(const HloUse& use, int64 use_time) {
    uses_.push_back(use);
    use_times_.push_back(use_time);
  }

  // Mutable access to the allocation sequence owned by this AllocationValue.
  AllocationSequence* allocation_sequence() { return &allocation_sequence_; }

  std::string ToString() const;
  std::string ToShortString() const;

 private:
  const HloValue* value_;
  HloPosition defining_position_;
  std::vector<HloUse> uses_;
  std::vector<int64> use_times_;
  AllocationSequence allocation_sequence_;
};
using AllocationSequenceList = std::vector<ValueAndAllocationSequence>;
// Runs the MemorySpaceAssignment pass.
static StatusOr<std::unique_ptr<PresetAssignments>> Run(
@ -545,7 +643,7 @@ class MemorySpaceAssignment {
Options options_;
std::vector<HloInstruction*> flattened_instructions_;
absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
AllocationSequenceList allocation_sequence_list_;
AllocationSequence allocations_;
std::unique_ptr<PresetAssignments> preset_assignments_;
// These maps hold vectors of new instructions that need to be scheduled after
@ -562,6 +660,7 @@ class MemorySpaceAssignment {
struct RequiredMemoryAssignment {
// Memory space the assignment is required to be in.
MemorySpaceAssignment::MemorySpace memory_space;
// Time at which the requirement applies.
// NOTE(review): presumably a logical schedule time -- confirm against callers.
int64 time;
// Chunk associated with the requirement; optional, so it may be unset.
// NOTE(review): presumably only set for alternate-memory requirements --
// confirm.
absl::optional<HeapSimulator::Chunk> chunk;
};
// A struct representing an asynchronous copy with its logical start and end
@ -614,14 +713,15 @@ class AsynchronousCopyOrdering {
class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
public:
using MemorySpace = MemorySpaceAssignment::MemorySpace;
using AllocationValue = MemorySpaceAssignment::AllocationValue;
AlternateMemoryBestFitHeap(
MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list,
MemorySpaceAssignment::AllocationSequence* allocations,
const MemorySpaceAssignment::Options& options,
const HloAliasAnalysis& alias_analysis,
const HloLiveRange& hlo_live_range)
: GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
allocation_sequence_list_(allocation_sequence_list),
allocations_(allocations),
options_(options),
alias_analysis_(alias_analysis),
hlo_live_range_(hlo_live_range) {
@ -632,7 +732,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
}
// Allocates a buffer in preferred memory with whole program lifetime and
// enables prefetching prefech_candidate from default memory across program
// enables prefetching prefetch_candidate from default memory across program
// boundaries.
void AllocateCrossProgramPrefetchBuffer(
HloModule* module, absl::optional<BufferInterval> prefetch_candidate);
@ -660,12 +760,11 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
struct AllocationRequest {
int64 start_time;
int64 end_time;
const std::vector<int64>* use_times;
int64 latest_prefetch_time;
int64 size;
absl::optional<int64> preferred_offset;
HloUse use;
const HloValue* buffer;
MemorySpaceAssignment::AllocationSequence* allocations;
MemorySpaceAssignment::AllocationValue* allocation_value;
};
// Given an allocation sequence, returns the live allocation at time with a
@ -674,15 +773,22 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
// Returns true if a buffer is required to be in default memory at a
// particular time. A buffer may be required to be in default memory because
// it is a parameter in default memory or an ouput in default memory.
bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;
// Returns the required assignment at a particular time, if available.
absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
const HloValue* buffer, int64 time) const;
// Returns true if this buffer is allowed to be placed in the alternate
// memory.
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
// Returns true if the use is allowed in the alternate memory.
bool IsUseAllowedInAlternateMemory(const HloUse& use) const;
// Given an HloValue, creates AllocationValue objects (each of which owns its
// AllocationSequence) and appends them into allocation_values.
void CreateAllocationValues(const HloValue* value,
std::vector<AllocationValue>* allocation_values);
// Finds an allocation for the given interval.
//
// It performs three things in the following order:
@ -715,10 +821,21 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// availability if no preferred offset is given, or at the preferred_offset if
// it is given.
absl::optional<ChunkCandidate> FindBestChunkCandidate(
int64 end_time, const std::vector<int64>& use_times,
absl::optional<int64> preferred_offset,
const AllocationRequest& request, absl::optional<int64> preferred_offset,
BufferInterval* alternate_mem_interval) const;
// At the end of an allocation with a sequential call (while, conditional, and
// call), this function adds the necessary aliased assignments within the
// called computations.
void AddAliasedRequiredAssignmentsForSequentialCall(
const HloUse& use,
const MemorySpaceAssignment::Allocation* aliased_allocation);
// Propagates aliased required assignment for a given position.
void AddAliasedRequiredAssignment(
const HloInstruction* instruction, ShapeIndex index,
const MemorySpaceAssignment::Allocation* aliased_allocation);
// Adds inputs and outputs as required assignments.
void AddInputAndOutputRequiredAssignments();
@ -734,7 +851,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
std::vector<const BufferInterval*> GetSortedColocatedIntervals(
const BufferInterval& interval) const;
// Since the allocations are recorded to the AllocationSequenceList, we don't
// Since the allocations are recorded to the AllocationSequence, we don't
// maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
// to avoid unnecessarily adding the chunk to the chunk map.
void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
@ -749,8 +866,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// Adds an asynchronous copy to the allocations.
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
MemorySpace memory_space, Chunk chunk, int64 start_time,
int64 end_time, int64 copy_done_schedule_before_time,
MemorySpace memory_space, absl::optional<Chunk> chunk,
int64 start_time, int64 end_time,
int64 copy_done_schedule_before_time,
MemorySpaceAssignment::AllocationSequence* allocations);
// This method is used for committing the chunk candidate but adding it to
@ -768,7 +886,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
return options_.max_size_in_bytes - reserved_in_bytes_;
}
MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list_;
MemorySpaceAssignment::AllocationSequence* allocations_;
const MemorySpaceAssignment::Options& options_;
const HloAliasAnalysis& alias_analysis_;
const HloLiveRange& hlo_live_range_;
@ -784,6 +902,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
required_assignments_;
// Number of bytes reserved in alternate memory space.
int64 reserved_in_bytes_ = 0;
int64 global_max_time_;
};
} // namespace xla

View File

@ -103,6 +103,21 @@ class MemorySpaceAssignmentTest : public HloTestBase,
return true;
};
// Only check parameters in default memory if the original module didn't
// have the parameters in alternate memory.
bool check_parameters_in_default_memory = true;
for (const HloInstruction* parameter :
module->entry_computation()->parameter_instructions()) {
ShapeUtil::ForEachSubshape(
parameter->shape(),
[&](const Shape& subshape, const ShapeIndex& /*index*/) {
if (subshape.has_layout() &&
subshape.layout().memory_space() == kAlternateMemorySpace) {
check_parameters_in_default_memory = false;
}
});
}
MemorySpaceAssignment::Options options;
options.alternate_memory_space = kAlternateMemorySpace;
options.max_size_in_bytes = 128;
@ -125,6 +140,9 @@ class MemorySpaceAssignmentTest : public HloTestBase,
MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
options)
.ValueOrDie();
if (check_parameters_in_default_memory) {
CheckParametersInDefaultMemory(module);
}
CheckPresetAssignments(preset_assignments.get());
return preset_assignments;
}
@ -148,6 +166,24 @@ class MemorySpaceAssignmentTest : public HloTestBase,
}
}
// Expects that no subshape of any entry computation parameter has a layout
// whose memory space is kAlternateMemorySpace. Violations are reported through
// EXPECT_NE together with the offending parameter's string form.
void CheckParametersInDefaultMemory(const HloModule* module) {
  for (const HloInstruction* parameter :
       module->entry_computation()->parameter_instructions()) {
    auto expect_default_memory = [&](const Shape& subshape,
                                     const ShapeIndex& /*index*/) {
      // A subshape without a layout carries no memory space to inspect.
      if (!subshape.has_layout()) {
        return;
      }
      EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace)
          << "Parameter not in default memory: " << parameter->ToString();
    };
    ShapeUtil::ForEachSubshape(parameter->shape(), expect_default_memory);
  }
}
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
@ -1250,11 +1286,282 @@ TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) {
if (instruction->opcode() == HloOpcode::kWhile) {
const Shape& while_subshape =
ShapeUtil::GetSubshape(instruction->shape(), {0});
EXPECT_NE(while_subshape.layout().memory_space(), kAlternateMemorySpace);
// We expect shape {0} to either be in default memory for the entire while
// loop or there has to be an eviction within the while loop.
if (while_subshape.layout().memory_space() == kAlternateMemorySpace) {
const HloInstruction* body_param =
instruction->while_body()->parameter_instruction(0);
const HloInstruction* gte = nullptr;
for (const HloInstruction* user : body_param->users()) {
if (user->opcode() == HloOpcode::kGetTupleElement &&
user->tuple_index() == 0) {
gte = user;
break;
}
}
EXPECT_NE(gte, nullptr);
const HloInstruction* copy_start = nullptr;
for (const HloInstruction* user : gte->users()) {
if (user->opcode() == HloOpcode::kCopyStart) {
copy_start = user;
break;
}
}
EXPECT_NE(copy_start, nullptr);
const Shape& copy_start_subshape =
ShapeUtil::GetSubshape(copy_start->shape(), {0});
EXPECT_NE(copy_start_subshape.layout().memory_space(),
kAlternateMemorySpace);
}
}
}
}
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoops) {
// Runs memory space assignment on a module with two back-to-back while loops
// in the entry computation. The second loop's init tuple (tuple.2) is built
// from the first loop's results: add.3 consumes element 0 of %while and
// get-tuple-element.5 is element 1 of %while. The test body makes no explicit
// assertions; it relies on whatever checks the AssignMemorySpace fixture
// helper performs.
absl::string_view hlo_string = R"(
HloModule WhileAllocationBug, is_scheduled=true
%WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
%constant.1 = f32[] constant(1)
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
}
%WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
%constant = f32[] constant(50)
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
}
%WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
%constant.1 = f32[] constant(1)
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
}
%WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
%constant = f32[] constant(50)
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
}
ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
%param_iter = f32[] parameter(1)
%param_data = f32[4,3]{1,0} parameter(0)
%p2 = f32[4,3]{1,0} parameter(2)
%neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
%neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
%neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
%neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
%neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
%neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
%neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
%add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
%tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
%while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
%get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
%add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
%get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
%tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} get-tuple-element.5, f32[] %param_iter)
%while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
%get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.6, f32[4,3]{1,0} %add.3)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
}
TEST_P(MemorySpaceAssignmentTest, WhileLiveRangeBug) {
// Tests against while live ranges being incorrect and the verifier
// complaining about a conflict.
absl::string_view hlo_string = R"(
HloModule WhileAllocationBug, is_scheduled=true
%WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
%neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
%neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
%neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
%neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
%neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
%neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
%neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
%neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
%neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
%neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
%neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
%constant.1 = f32[] constant(1)
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
}
%WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
%constant = f32[] constant(50)
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
}
ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
%param_iter = f32[] parameter(1)
%param_data = f32[4,3]{1,0} parameter(0)
%p2 = f32[4,3]{1,0} parameter(2)
%neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
%neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
%neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
%neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
%neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
%neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
%neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
%add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
%tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
%while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
%get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
%get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
%add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %add.3)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
}
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
// Tests against a bug where, with consecutive while loops sharing one buffer
// (the value doesn't change in the buffer), the parameter can be colored in
// the alternate memory space.
absl::string_view hlo_string = R"(
HloModule WhileAllocationBug, is_scheduled=true
%WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
%neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
%neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
%neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
%neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
%neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
%neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
%neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
%neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
%neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
%neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
%neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
%constant.1 = f32[] constant(1)
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
}
%WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
%constant = f32[] constant(50)
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
}
%WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
%neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
%neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
%neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
%neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
%neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
%neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
%neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
%neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
%neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
%neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
%neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
%constant.1 = f32[] constant(1)
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
}
%WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
%constant = f32[] constant(50)
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
}
ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
%param_iter = f32[] parameter(1)
%param_data = f32[4,3]{1,0} parameter(0)
%p2 = f32[4,3]{1,0} parameter(2)
%neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
%neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
%neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
%neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
%neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
%neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
%neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
%add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
%tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
%while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
%get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
%add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
%tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} param_data, f32[] %param_iter)
%while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
%get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
%get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=1
ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %get-tuple-element.6)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
}
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
// Having control_predecessors on an HLO was preventing us from DCEing an op
// that doesn't have any users (tuple.1). The scheduler assumes the graph is
@ -2070,12 +2377,6 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
/*max_prefetch_interval=*/25);
int64 memory_space_across_while = kDefaultMemorySpace;
bool allocate_across_sequential_calls = GetParam();
if (allocate_across_sequential_calls) {
memory_space_across_while = kAlternateMemorySpace;
}
// Index {0} of the while loop argument is not written inside the while loop,
// so it can be trivially placed in the alternate memory space.
*ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
@ -2087,12 +2388,11 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
LayoutUtil::MakeLayout(
/*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
kDefaultMemorySpace);
// Index {2} of the while loop argument is placed in the alternate memory if
// we enable the allocate_across_sequential_calls option.
// Index {2} of the while loop is placed in the default memory.
*ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
LayoutUtil::MakeLayout(
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
memory_space_across_while);
kDefaultMemorySpace);
// Expect the layout for the while loop and its aliased buffers.
EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));