[XLA] Better support for while loops in memory space assignment.
This CL makes changes to support the aliasing requirements of while loops. It introduces the AllocationValue object which is similar to HloValue. While HloValue may include multiple HloPositions across multiple Computations (that alias with each other), there is one AllocationValue object per non-trivial (excluding GTE, Tuple, and Bitcast) position. This data structure allows memory space assignment to visit positions and uses in each computation separately, and then propagate the aliased allocation decisions to other AllocationValue objects. Using this new data structure, memory space assignment now properly propagates aliased positions and uses using required assignments. This CL also introduces GetTransitiveColocationIntervals to HeapSimulator that returns colocated intervals instead of HloValues. Using this API, memory space assignment represents aliased logical times where AllocationValues must have the same buffer assignment (e.g. while loop body parameter and while loop output). PiperOrigin-RevId: 306260022 Change-Id: I169cc2a6ee78c5ad9b7bcc366e519f7c5750a4e2
This commit is contained in:
parent
e646d8878a
commit
517e0274df
@ -1392,23 +1392,22 @@ Status BufferAssigner::AssignPresetBuffers(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
|
const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
|
||||||
const HloDataflowAnalysis& dataflow_analysis =
|
|
||||||
alias_analysis.dataflow_analysis();
|
|
||||||
|
|
||||||
for (auto& position_and_chunk : preset_assignments_->chunks()) {
|
for (auto& position_and_chunk : preset_assignments_->chunks()) {
|
||||||
const HloPosition& position = position_and_chunk.first;
|
const HloPosition& defining_position = position_and_chunk.first;
|
||||||
const HloValue& value = dataflow_analysis.GetUniqueValueAt(
|
const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
|
||||||
position.instruction, position.index);
|
defining_position.instruction, defining_position.index);
|
||||||
VLOG(3) << "Preset allocation for value: " << value.ToShortString();
|
for (const HloValue* value : buffer.values()) {
|
||||||
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
|
VLOG(3) << "Preset allocation for value: " << value->ToShortString();
|
||||||
auto preset_allocations_iter = preset_allocations.find(value.color());
|
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
|
||||||
CHECK(preset_allocations_iter != preset_allocations.end())
|
auto preset_allocations_iter = preset_allocations.find(value->color());
|
||||||
<< "No preset value allocation for color " << value.color() << " for "
|
CHECK(preset_allocations_iter != preset_allocations.end())
|
||||||
<< value.ToShortString() << " found.";
|
<< "No preset value allocation for color " << value->color()
|
||||||
preset_allocations_iter->second->AddAssignment(value, chunk.offset,
|
<< " for " << value->ToShortString() << " found.";
|
||||||
chunk.size);
|
preset_allocations_iter->second->AddAssignment(*value, chunk.offset,
|
||||||
|
chunk.size);
|
||||||
|
}
|
||||||
|
|
||||||
const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value);
|
|
||||||
assigned_buffers->insert(&buffer);
|
assigned_buffers->insert(&buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -835,13 +835,6 @@ TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
|
|||||||
// Set only one preset assignment for while data and its aliases.
|
// Set only one preset assignment for while data and its aliases.
|
||||||
auto preset_assignments = absl::make_unique<PresetAssignments>();
|
auto preset_assignments = absl::make_unique<PresetAssignments>();
|
||||||
preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
|
preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
|
||||||
preset_assignments->add_chunk({while_op, {1}}, {/*offset=*/100, /*size=*/40});
|
|
||||||
preset_assignments->add_chunk({cond_param, {1}},
|
|
||||||
{/*offset=*/100, /*size=*/40});
|
|
||||||
preset_assignments->add_chunk({body_param, {1}},
|
|
||||||
{/*offset=*/100, /*size=*/40});
|
|
||||||
preset_assignments->add_chunk({body_data_next, {}},
|
|
||||||
{/*offset=*/100, /*size=*/40});
|
|
||||||
preset_assignments->assignment_information_for_space(/*memory_space=*/1)
|
preset_assignments->assignment_information_for_space(/*memory_space=*/1)
|
||||||
->size = 140;
|
->size = 140;
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -363,7 +363,7 @@ class MemorySpaceAssignment {
|
|||||||
class Allocation {
|
class Allocation {
|
||||||
public:
|
public:
|
||||||
Allocation(HloPosition defining_position, MemorySpace memory_space,
|
Allocation(HloPosition defining_position, MemorySpace memory_space,
|
||||||
Chunk chunk, int64 start_time, int64 end_time)
|
absl::optional<Chunk> chunk, int64 start_time, int64 end_time)
|
||||||
: defining_position_(defining_position),
|
: defining_position_(defining_position),
|
||||||
memory_space_(memory_space),
|
memory_space_(memory_space),
|
||||||
chunk_(chunk),
|
chunk_(chunk),
|
||||||
@ -393,7 +393,7 @@ class MemorySpaceAssignment {
|
|||||||
|
|
||||||
const std::vector<HloUse>& uses() const { return uses_; }
|
const std::vector<HloUse>& uses() const { return uses_; }
|
||||||
MemorySpace memory_space() const { return memory_space_; }
|
MemorySpace memory_space() const { return memory_space_; }
|
||||||
Chunk chunk() const { return chunk_; }
|
Chunk chunk() const { return *chunk_; }
|
||||||
void set_start_time(int64 start_time) { start_time_ = start_time; }
|
void set_start_time(int64 start_time) { start_time_ = start_time; }
|
||||||
int64 start_time() const { return start_time_; }
|
int64 start_time() const { return start_time_; }
|
||||||
int64 end_time() const { return end_time_; }
|
int64 end_time() const { return end_time_; }
|
||||||
@ -405,10 +405,14 @@ class MemorySpaceAssignment {
|
|||||||
HloInstruction* tuple,
|
HloInstruction* tuple,
|
||||||
ShapeIndex shape_index);
|
ShapeIndex shape_index);
|
||||||
|
|
||||||
|
// Recursively create kGetTupleElement instructions if the defining position
|
||||||
|
// shape is not an array. Returns the new instruction that has array shape.
|
||||||
|
HloInstruction* AddGetTupleElements();
|
||||||
|
|
||||||
HloPosition defining_position_;
|
HloPosition defining_position_;
|
||||||
std::vector<HloUse> uses_;
|
std::vector<HloUse> uses_;
|
||||||
MemorySpace memory_space_;
|
MemorySpace memory_space_;
|
||||||
Chunk chunk_;
|
absl::optional<Chunk> chunk_;
|
||||||
int64 start_time_;
|
int64 start_time_;
|
||||||
int64 end_time_;
|
int64 end_time_;
|
||||||
};
|
};
|
||||||
@ -417,8 +421,8 @@ class MemorySpaceAssignment {
|
|||||||
class CopyAllocation : public Allocation {
|
class CopyAllocation : public Allocation {
|
||||||
public:
|
public:
|
||||||
CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
|
CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
|
||||||
Chunk chunk, int64 start_time, int64 end_time,
|
absl::optional<Chunk> chunk, int64 start_time,
|
||||||
int64 copy_done_schedule_before_time)
|
int64 end_time, int64 copy_done_schedule_before_time)
|
||||||
: Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
|
: Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
|
||||||
start_time, end_time),
|
start_time, end_time),
|
||||||
prev_allocation_(prev_allocation),
|
prev_allocation_(prev_allocation),
|
||||||
@ -476,11 +480,105 @@ class MemorySpaceAssignment {
|
|||||||
};
|
};
|
||||||
|
|
||||||
using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
|
using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
|
||||||
struct ValueAndAllocationSequence {
|
// AllocationValue is used to break up HloValues for each non-trivial position
|
||||||
const HloValue* value;
|
// (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An
|
||||||
AllocationSequence sequence;
|
// HloValue may include positions and uses that alias with each other across
|
||||||
|
// multiple computations. We use this class to break these HloValues such that
|
||||||
|
// every AllocationValue has one defining position (that may alias with other
|
||||||
|
// AllocationValues). The uses field of the AllocationValue contains only the
|
||||||
|
// direct uses of the AllocationValue's defining position.
|
||||||
|
//
|
||||||
|
// For example, consider the following HLO snippet:
|
||||||
|
//
|
||||||
|
// Body {
|
||||||
|
// body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
// get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
|
||||||
|
// index=0
|
||||||
|
// ...
|
||||||
|
// ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Cond {
|
||||||
|
// cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
// ...
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// add.4 = f32[4,3]{1,0} add(...)
|
||||||
|
// tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
|
||||||
|
// while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
|
||||||
|
// get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
|
||||||
|
// add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
|
||||||
|
//
|
||||||
|
// This contains an HloValue that looks like the following:
|
||||||
|
// positions:
|
||||||
|
// add.4
|
||||||
|
// body_param {0}
|
||||||
|
// get-tuple-element.3
|
||||||
|
// tuple {0}
|
||||||
|
// cond_param {0}
|
||||||
|
// tuple.1 {0}
|
||||||
|
// while {0}
|
||||||
|
// get-tuple-element.5
|
||||||
|
// uses:
|
||||||
|
// add.1, operand 0
|
||||||
|
// tuple, operand 0
|
||||||
|
// while, operand 0 {0}
|
||||||
|
// add.5, operand 0
|
||||||
|
//
|
||||||
|
// We break this HloValue up into the following AllocationValues for each
|
||||||
|
// non-trivial position:
|
||||||
|
// AllocationValue1: computation = Entry
|
||||||
|
// position:
|
||||||
|
// add.4
|
||||||
|
// uses:
|
||||||
|
// while, operand 0 {0}
|
||||||
|
// AllocationValue2: computation = Cond
|
||||||
|
// position:
|
||||||
|
// cond_param {0}
|
||||||
|
// uses:
|
||||||
|
// AllocationValue3: computation = Body
|
||||||
|
// position:
|
||||||
|
// body_param {0}
|
||||||
|
// uses:
|
||||||
|
// add.1, operand 0
|
||||||
|
// tuple, operand 0
|
||||||
|
// AllocationValue4: computation = Entry
|
||||||
|
// position:
|
||||||
|
// while {0}
|
||||||
|
// uses:
|
||||||
|
// add.5, operand 0
|
||||||
|
class AllocationValue {
|
||||||
|
public:
|
||||||
|
AllocationValue(const HloValue* value, const HloPosition& position)
|
||||||
|
: value_(value), defining_position_(position) {}
|
||||||
|
|
||||||
|
const HloPosition& defining_position() const { return defining_position_; }
|
||||||
|
const HloInstruction* defining_instruction() const {
|
||||||
|
return defining_position().instruction;
|
||||||
|
}
|
||||||
|
const std::vector<HloUse>& uses() const { return uses_; }
|
||||||
|
const std::vector<int64>& use_times() const { return use_times_; }
|
||||||
|
const HloValue* value() const { return value_; }
|
||||||
|
const HloComputation* computation() const {
|
||||||
|
return defining_instruction()->parent();
|
||||||
|
}
|
||||||
|
AllocationSequence* allocation_sequence() { return &allocation_sequence_; }
|
||||||
|
|
||||||
|
void AddUse(const HloUse& use, int64 use_time) {
|
||||||
|
uses_.push_back(use);
|
||||||
|
use_times_.push_back(use_time);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
std::string ToShortString() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const HloValue* value_;
|
||||||
|
HloPosition defining_position_;
|
||||||
|
std::vector<HloUse> uses_;
|
||||||
|
std::vector<int64> use_times_;
|
||||||
|
AllocationSequence allocation_sequence_;
|
||||||
};
|
};
|
||||||
using AllocationSequenceList = std::vector<ValueAndAllocationSequence>;
|
|
||||||
|
|
||||||
// Runs the MemorySpaceAssignment pass.
|
// Runs the MemorySpaceAssignment pass.
|
||||||
static StatusOr<std::unique_ptr<PresetAssignments>> Run(
|
static StatusOr<std::unique_ptr<PresetAssignments>> Run(
|
||||||
@ -545,7 +643,7 @@ class MemorySpaceAssignment {
|
|||||||
Options options_;
|
Options options_;
|
||||||
std::vector<HloInstruction*> flattened_instructions_;
|
std::vector<HloInstruction*> flattened_instructions_;
|
||||||
absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
|
absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
|
||||||
AllocationSequenceList allocation_sequence_list_;
|
AllocationSequence allocations_;
|
||||||
std::unique_ptr<PresetAssignments> preset_assignments_;
|
std::unique_ptr<PresetAssignments> preset_assignments_;
|
||||||
|
|
||||||
// These maps hold vectors of new instructions that need to be scheduled after
|
// These maps hold vectors of new instructions that need to be scheduled after
|
||||||
@ -562,6 +660,7 @@ class MemorySpaceAssignment {
|
|||||||
struct RequiredMemoryAssignment {
|
struct RequiredMemoryAssignment {
|
||||||
MemorySpaceAssignment::MemorySpace memory_space;
|
MemorySpaceAssignment::MemorySpace memory_space;
|
||||||
int64 time;
|
int64 time;
|
||||||
|
absl::optional<HeapSimulator::Chunk> chunk;
|
||||||
};
|
};
|
||||||
|
|
||||||
// A struct representing an asynchronous copy with its logical start and end
|
// A struct representing an asynchronous copy with its logical start and end
|
||||||
@ -614,14 +713,15 @@ class AsynchronousCopyOrdering {
|
|||||||
class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||||
public:
|
public:
|
||||||
using MemorySpace = MemorySpaceAssignment::MemorySpace;
|
using MemorySpace = MemorySpaceAssignment::MemorySpace;
|
||||||
|
using AllocationValue = MemorySpaceAssignment::AllocationValue;
|
||||||
|
|
||||||
AlternateMemoryBestFitHeap(
|
AlternateMemoryBestFitHeap(
|
||||||
MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list,
|
MemorySpaceAssignment::AllocationSequence* allocations,
|
||||||
const MemorySpaceAssignment::Options& options,
|
const MemorySpaceAssignment::Options& options,
|
||||||
const HloAliasAnalysis& alias_analysis,
|
const HloAliasAnalysis& alias_analysis,
|
||||||
const HloLiveRange& hlo_live_range)
|
const HloLiveRange& hlo_live_range)
|
||||||
: GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
|
: GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
|
||||||
allocation_sequence_list_(allocation_sequence_list),
|
allocations_(allocations),
|
||||||
options_(options),
|
options_(options),
|
||||||
alias_analysis_(alias_analysis),
|
alias_analysis_(alias_analysis),
|
||||||
hlo_live_range_(hlo_live_range) {
|
hlo_live_range_(hlo_live_range) {
|
||||||
@ -632,7 +732,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Allocates a buffer in preferred memory with whole program lifetime and
|
// Allocates a buffer in preferred memory with whole program lifetime and
|
||||||
// enables prefetching prefech_candidate from default memory across program
|
// enables prefetching prefetch_candidate from default memory across program
|
||||||
// boundaries.
|
// boundaries.
|
||||||
void AllocateCrossProgramPrefetchBuffer(
|
void AllocateCrossProgramPrefetchBuffer(
|
||||||
HloModule* module, absl::optional<BufferInterval> prefetch_candidate);
|
HloModule* module, absl::optional<BufferInterval> prefetch_candidate);
|
||||||
@ -660,12 +760,11 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
struct AllocationRequest {
|
struct AllocationRequest {
|
||||||
int64 start_time;
|
int64 start_time;
|
||||||
int64 end_time;
|
int64 end_time;
|
||||||
const std::vector<int64>* use_times;
|
|
||||||
int64 latest_prefetch_time;
|
int64 latest_prefetch_time;
|
||||||
int64 size;
|
int64 size;
|
||||||
|
absl::optional<int64> preferred_offset;
|
||||||
HloUse use;
|
HloUse use;
|
||||||
const HloValue* buffer;
|
MemorySpaceAssignment::AllocationValue* allocation_value;
|
||||||
MemorySpaceAssignment::AllocationSequence* allocations;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Given an allocation sequence, returns the live allocation at time with a
|
// Given an allocation sequence, returns the live allocation at time with a
|
||||||
@ -674,15 +773,22 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
|
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
|
||||||
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
|
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
|
||||||
|
|
||||||
// Returns true if a buffer is required to be in default memory at a
|
// Returns the required assignment at a particular time, if available.
|
||||||
// particular time. A buffer may be required to be in default memory because
|
absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
|
||||||
// it is a parameter in default memory or an ouput in default memory.
|
const HloValue* buffer, int64 time) const;
|
||||||
bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;
|
|
||||||
|
|
||||||
// Returns true if this buffer is allowed to be placed in the alternate
|
// Returns true if this buffer is allowed to be placed in the alternate
|
||||||
// memory.
|
// memory.
|
||||||
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
|
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
|
||||||
|
|
||||||
|
// Returns true if the use is allowed in the alternate memory.
|
||||||
|
bool IsUseAllowedInAlternateMemory(const HloUse& use) const;
|
||||||
|
|
||||||
|
// Given an HloValue, creates AllocationValue objects and corresponding
|
||||||
|
// AllocationSequences and appends them into allocation_sequence_list_.
|
||||||
|
void CreateAllocationValues(const HloValue* value,
|
||||||
|
std::vector<AllocationValue>* allocation_values);
|
||||||
|
|
||||||
// Finds an allocation for the given interval.
|
// Finds an allocation for the given interval.
|
||||||
//
|
//
|
||||||
// It performs three things in the following order:
|
// It performs three things in the following order:
|
||||||
@ -715,10 +821,21 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
// availability if no preferred offset is given, or at the preferred_offset if
|
// availability if no preferred offset is given, or at the preferred_offset if
|
||||||
// it is given.
|
// it is given.
|
||||||
absl::optional<ChunkCandidate> FindBestChunkCandidate(
|
absl::optional<ChunkCandidate> FindBestChunkCandidate(
|
||||||
int64 end_time, const std::vector<int64>& use_times,
|
const AllocationRequest& request, absl::optional<int64> preferred_offset,
|
||||||
absl::optional<int64> preferred_offset,
|
|
||||||
BufferInterval* alternate_mem_interval) const;
|
BufferInterval* alternate_mem_interval) const;
|
||||||
|
|
||||||
|
// At the end of an allocation with a sequential call (while, conditional, and
|
||||||
|
// call), this function adds the necessary aliased assignments within the
|
||||||
|
// called computations.
|
||||||
|
void AddAliasedRequiredAssignmentsForSequentialCall(
|
||||||
|
const HloUse& use,
|
||||||
|
const MemorySpaceAssignment::Allocation* aliased_allocation);
|
||||||
|
|
||||||
|
// Propagates aliased required assignment for a given position.
|
||||||
|
void AddAliasedRequiredAssignment(
|
||||||
|
const HloInstruction* instruction, ShapeIndex index,
|
||||||
|
const MemorySpaceAssignment::Allocation* aliased_allocation);
|
||||||
|
|
||||||
// Adds input and outputs as required assignments.
|
// Adds input and outputs as required assignments.
|
||||||
void AddInputAndOutputRequiredAssignments();
|
void AddInputAndOutputRequiredAssignments();
|
||||||
|
|
||||||
@ -734,7 +851,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
std::vector<const BufferInterval*> GetSortedColocatedIntervals(
|
std::vector<const BufferInterval*> GetSortedColocatedIntervals(
|
||||||
const BufferInterval& interval) const;
|
const BufferInterval& interval) const;
|
||||||
|
|
||||||
// Since the allocations are recorded to the AllocationSequenceList, we don't
|
// Since the allocations are recorded to the AllocationSequence, we don't
|
||||||
// maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
|
// maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
|
||||||
// to avoid unnecessarily adding the chunk to the chunk map.
|
// to avoid unnecessarily adding the chunk to the chunk map.
|
||||||
void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
|
void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
|
||||||
@ -749,8 +866,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
|
|
||||||
// Adds an asynchronous copy to the allocations.
|
// Adds an asynchronous copy to the allocations.
|
||||||
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
|
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
|
||||||
MemorySpace memory_space, Chunk chunk, int64 start_time,
|
MemorySpace memory_space, absl::optional<Chunk> chunk,
|
||||||
int64 end_time, int64 copy_done_schedule_before_time,
|
int64 start_time, int64 end_time,
|
||||||
|
int64 copy_done_schedule_before_time,
|
||||||
MemorySpaceAssignment::AllocationSequence* allocations);
|
MemorySpaceAssignment::AllocationSequence* allocations);
|
||||||
|
|
||||||
// This method is used for committing the chunk candidate but adding it to
|
// This method is used for committing the chunk candidate but adding it to
|
||||||
@ -768,7 +886,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
return options_.max_size_in_bytes - reserved_in_bytes_;
|
return options_.max_size_in_bytes - reserved_in_bytes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list_;
|
MemorySpaceAssignment::AllocationSequence* allocations_;
|
||||||
const MemorySpaceAssignment::Options& options_;
|
const MemorySpaceAssignment::Options& options_;
|
||||||
const HloAliasAnalysis& alias_analysis_;
|
const HloAliasAnalysis& alias_analysis_;
|
||||||
const HloLiveRange& hlo_live_range_;
|
const HloLiveRange& hlo_live_range_;
|
||||||
@ -784,6 +902,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
required_assignments_;
|
required_assignments_;
|
||||||
// Number of bytes reserved in alternate memory space.
|
// Number of bytes reserved in alternate memory space.
|
||||||
int64 reserved_in_bytes_ = 0;
|
int64 reserved_in_bytes_ = 0;
|
||||||
|
int64 global_max_time_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -103,6 +103,21 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
|||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Only check parameters in default memory if the original module didn't
|
||||||
|
// have the parameters in alternate memory.
|
||||||
|
bool check_parameters_in_default_memory = true;
|
||||||
|
for (const HloInstruction* parameter :
|
||||||
|
module->entry_computation()->parameter_instructions()) {
|
||||||
|
ShapeUtil::ForEachSubshape(
|
||||||
|
parameter->shape(),
|
||||||
|
[&](const Shape& subshape, const ShapeIndex& /*index*/) {
|
||||||
|
if (subshape.has_layout() &&
|
||||||
|
subshape.layout().memory_space() == kAlternateMemorySpace) {
|
||||||
|
check_parameters_in_default_memory = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
MemorySpaceAssignment::Options options;
|
MemorySpaceAssignment::Options options;
|
||||||
options.alternate_memory_space = kAlternateMemorySpace;
|
options.alternate_memory_space = kAlternateMemorySpace;
|
||||||
options.max_size_in_bytes = 128;
|
options.max_size_in_bytes = 128;
|
||||||
@ -125,6 +140,9 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
|||||||
MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
|
MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
|
||||||
options)
|
options)
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
|
if (check_parameters_in_default_memory) {
|
||||||
|
CheckParametersInDefaultMemory(module);
|
||||||
|
}
|
||||||
CheckPresetAssignments(preset_assignments.get());
|
CheckPresetAssignments(preset_assignments.get());
|
||||||
return preset_assignments;
|
return preset_assignments;
|
||||||
}
|
}
|
||||||
@ -148,6 +166,24 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CheckParametersInDefaultMemory(const HloModule* module) {
|
||||||
|
// Check that all the entry parameter subshapes are placed in default
|
||||||
|
// memory.
|
||||||
|
const HloComputation* entry_computation = module->entry_computation();
|
||||||
|
for (const HloInstruction* parameter :
|
||||||
|
entry_computation->parameter_instructions()) {
|
||||||
|
ShapeUtil::ForEachSubshape(
|
||||||
|
parameter->shape(),
|
||||||
|
[&](const Shape& subshape, const ShapeIndex& /*index*/) {
|
||||||
|
if (subshape.has_layout()) {
|
||||||
|
EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace)
|
||||||
|
<< "Parameter not in default memory: "
|
||||||
|
<< parameter->ToString();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
|
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
||||||
@ -1250,11 +1286,282 @@ TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) {
|
|||||||
if (instruction->opcode() == HloOpcode::kWhile) {
|
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||||
const Shape& while_subshape =
|
const Shape& while_subshape =
|
||||||
ShapeUtil::GetSubshape(instruction->shape(), {0});
|
ShapeUtil::GetSubshape(instruction->shape(), {0});
|
||||||
EXPECT_NE(while_subshape.layout().memory_space(), kAlternateMemorySpace);
|
// We expect shape {0} to either be in default memory for the entire while
|
||||||
|
// loop or there has to be an eviction within the while loop.
|
||||||
|
if (while_subshape.layout().memory_space() == kAlternateMemorySpace) {
|
||||||
|
const HloInstruction* body_param =
|
||||||
|
instruction->while_body()->parameter_instruction(0);
|
||||||
|
const HloInstruction* gte = nullptr;
|
||||||
|
for (const HloInstruction* user : body_param->users()) {
|
||||||
|
if (user->opcode() == HloOpcode::kGetTupleElement &&
|
||||||
|
user->tuple_index() == 0) {
|
||||||
|
gte = user;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_NE(gte, nullptr);
|
||||||
|
const HloInstruction* copy_start = nullptr;
|
||||||
|
for (const HloInstruction* user : gte->users()) {
|
||||||
|
if (user->opcode() == HloOpcode::kCopyStart) {
|
||||||
|
copy_start = user;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_NE(copy_start, nullptr);
|
||||||
|
const Shape& copy_start_subshape =
|
||||||
|
ShapeUtil::GetSubshape(copy_start->shape(), {0});
|
||||||
|
|
||||||
|
EXPECT_NE(copy_start_subshape.layout().memory_space(),
|
||||||
|
kAlternateMemorySpace);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoops) {
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule WhileAllocationBug, is_scheduled=true
|
||||||
|
|
||||||
|
%WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
|
||||||
|
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
|
||||||
|
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
|
||||||
|
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
|
||||||
|
%constant.1 = f32[] constant(1)
|
||||||
|
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
|
||||||
|
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
|
||||||
|
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
|
||||||
|
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
|
||||||
|
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
|
||||||
|
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
|
||||||
|
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
|
||||||
|
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
|
||||||
|
%constant = f32[] constant(50)
|
||||||
|
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
|
||||||
|
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
|
||||||
|
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
|
||||||
|
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
|
||||||
|
%constant.1 = f32[] constant(1)
|
||||||
|
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
|
||||||
|
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
|
||||||
|
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
|
||||||
|
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
|
||||||
|
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
|
||||||
|
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
|
||||||
|
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
|
||||||
|
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
|
||||||
|
%constant = f32[] constant(50)
|
||||||
|
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
|
||||||
|
%param_iter = f32[] parameter(1)
|
||||||
|
%param_data = f32[4,3]{1,0} parameter(0)
|
||||||
|
%p2 = f32[4,3]{1,0} parameter(2)
|
||||||
|
%neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
|
||||||
|
%neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
|
||||||
|
%neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
|
||||||
|
%neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
|
||||||
|
%neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
|
||||||
|
%neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
|
||||||
|
%neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
|
||||||
|
%add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
|
||||||
|
%tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
|
||||||
|
%while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
|
||||||
|
%get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
|
||||||
|
%add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
|
||||||
|
%get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
|
||||||
|
%tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} get-tuple-element.5, f32[] %param_iter)
|
||||||
|
%while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
|
||||||
|
%get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
|
||||||
|
ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.6, f32[4,3]{1,0} %add.3)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
AssignMemorySpace(module.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(MemorySpaceAssignmentTest, WhileLiveRangeBug) {
|
||||||
|
// Tests against while live ranges being incorrect and the verifier
|
||||||
|
// complaining about a conflict.
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule WhileAllocationBug, is_scheduled=true
|
||||||
|
|
||||||
|
%WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
|
||||||
|
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
|
||||||
|
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
|
||||||
|
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
|
||||||
|
%neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
|
||||||
|
%neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
|
||||||
|
%neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
|
||||||
|
%neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
|
||||||
|
%neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
|
||||||
|
%neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
|
||||||
|
%neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
|
||||||
|
%neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
|
||||||
|
%neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
|
||||||
|
%neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
|
||||||
|
%neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
|
||||||
|
%constant.1 = f32[] constant(1)
|
||||||
|
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
|
||||||
|
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
|
||||||
|
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
|
||||||
|
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
|
||||||
|
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
|
||||||
|
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
|
||||||
|
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
|
||||||
|
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
|
||||||
|
%constant = f32[] constant(50)
|
||||||
|
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
|
||||||
|
%param_iter = f32[] parameter(1)
|
||||||
|
%param_data = f32[4,3]{1,0} parameter(0)
|
||||||
|
%p2 = f32[4,3]{1,0} parameter(2)
|
||||||
|
%neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
|
||||||
|
%neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
|
||||||
|
%neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
|
||||||
|
%neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
|
||||||
|
%neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
|
||||||
|
%neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
|
||||||
|
%neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
|
||||||
|
%add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
|
||||||
|
%tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
|
||||||
|
%while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
|
||||||
|
%get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
|
||||||
|
%get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
|
||||||
|
%add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
|
||||||
|
ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %add.3)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
AssignMemorySpace(module.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
|
||||||
|
// Tests against a bug when there are consecutive while loops with one buffer
|
||||||
|
// (the value doesn't change in the buffer), the parameter can be colored in
|
||||||
|
// the alternate memory space.
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule WhileAllocationBug, is_scheduled=true
|
||||||
|
|
||||||
|
%WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
|
||||||
|
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
|
||||||
|
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
|
||||||
|
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
|
||||||
|
%neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
|
||||||
|
%neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
|
||||||
|
%neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
|
||||||
|
%neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
|
||||||
|
%neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
|
||||||
|
%neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
|
||||||
|
%neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
|
||||||
|
%neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
|
||||||
|
%neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
|
||||||
|
%neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
|
||||||
|
%neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
|
||||||
|
%constant.1 = f32[] constant(1)
|
||||||
|
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
|
||||||
|
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
|
||||||
|
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
|
||||||
|
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
|
||||||
|
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
|
||||||
|
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
|
||||||
|
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
|
||||||
|
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
|
||||||
|
%constant = f32[] constant(50)
|
||||||
|
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
|
||||||
|
%body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
|
||||||
|
%get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
|
||||||
|
%get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
|
||||||
|
%neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
|
||||||
|
%neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
|
||||||
|
%neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
|
||||||
|
%neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
|
||||||
|
%neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
|
||||||
|
%neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
|
||||||
|
%neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
|
||||||
|
%neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
|
||||||
|
%neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
|
||||||
|
%neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
|
||||||
|
%neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
|
||||||
|
%constant.1 = f32[] constant(1)
|
||||||
|
%add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
|
||||||
|
%constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
|
||||||
|
%multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
|
||||||
|
%multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
|
||||||
|
%add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
|
||||||
|
%add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
|
||||||
|
ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
|
||||||
|
}
|
||||||
|
|
||||||
|
%WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
|
||||||
|
%cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
|
||||||
|
%get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
|
||||||
|
%constant = f32[] constant(50)
|
||||||
|
ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
|
||||||
|
%param_iter = f32[] parameter(1)
|
||||||
|
%param_data = f32[4,3]{1,0} parameter(0)
|
||||||
|
%p2 = f32[4,3]{1,0} parameter(2)
|
||||||
|
%neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
|
||||||
|
%neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
|
||||||
|
%neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
|
||||||
|
%neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
|
||||||
|
%neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
|
||||||
|
%neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
|
||||||
|
%neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
|
||||||
|
%add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
|
||||||
|
%tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
|
||||||
|
%while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
|
||||||
|
%get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
|
||||||
|
%add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
|
||||||
|
%tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} param_data, f32[] %param_iter)
|
||||||
|
%while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
|
||||||
|
%get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
|
||||||
|
%get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=1
|
||||||
|
ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %get-tuple-element.6)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
AssignMemorySpace(module.get());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
|
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
|
||||||
// Having control_predecessors on an HLO was preventing us from DCEing an op
|
// Having control_predecessors on an HLO was preventing us from DCEing an op
|
||||||
// that doesn't have any users (tuple.1). The scheduler assumes the graph is
|
// that doesn't have any users (tuple.1). The scheduler assumes the graph is
|
||||||
@ -2070,12 +2377,6 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
|
|||||||
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
|
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
|
||||||
/*max_prefetch_interval=*/25);
|
/*max_prefetch_interval=*/25);
|
||||||
|
|
||||||
int64 memory_space_across_while = kDefaultMemorySpace;
|
|
||||||
bool allocate_across_sequential_calls = GetParam();
|
|
||||||
if (allocate_across_sequential_calls) {
|
|
||||||
memory_space_across_while = kAlternateMemorySpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Index {0} of the while loop argument is not written inside the while loop,
|
// Index {0} of the while loop argument is not written inside the while loop,
|
||||||
// so it can be trivially placed in the alternate memory space.
|
// so it can be trivially placed in the alternate memory space.
|
||||||
*ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
|
*ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
|
||||||
@ -2087,12 +2388,11 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
|
|||||||
LayoutUtil::MakeLayout(
|
LayoutUtil::MakeLayout(
|
||||||
/*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
|
/*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
|
||||||
kDefaultMemorySpace);
|
kDefaultMemorySpace);
|
||||||
// Index {2} of the while loop argument is placed in the alternate memory if
|
// Index {2} of the while loop is placed in the default memory.
|
||||||
// we enable the allocate_across_sequential_calls option.
|
|
||||||
*ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
|
*ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
|
||||||
LayoutUtil::MakeLayout(
|
LayoutUtil::MakeLayout(
|
||||||
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
|
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
|
||||||
memory_space_across_while);
|
kDefaultMemorySpace);
|
||||||
|
|
||||||
// Expect the layout for the while loop and its aliased buffers.
|
// Expect the layout for the while loop and its aliased buffers.
|
||||||
EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
|
EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user