[XLA] Add a verifier to memory space assignment to check against overlapping buffers
PiperOrigin-RevId: 287573437 Change-Id: Ic28ee1bd4ff191c2194fb001880530eeeb4acea2
This commit is contained in:
parent
80c685a55f
commit
aa38702d75
@ -31,6 +31,12 @@ namespace xla {
|
||||
using absl::flat_hash_map;
|
||||
using absl::flat_hash_set;
|
||||
|
||||
// Reports whether this chunk and `other_chunk` intersect, i.e. whether each
// chunk's offset lies strictly before the other's chunk_end(). Chunks that
// merely touch end-to-start do not count as overlapping. CHECK-fails if either
// chunk has size zero.
bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  // This chunk begins at or after the point where the other one ends.
  if (!(offset < other_chunk.chunk_end())) {
    return false;
  }
  return other_chunk.offset < chunk_end();
}
|
||||
|
||||
/*static*/
|
||||
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
|
||||
const HloSchedule& schedule,
|
||||
@ -591,8 +597,7 @@ void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) {
|
||||
|
||||
using Chunk = HeapSimulator::Chunk;
|
||||
|
||||
void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add(
|
||||
int64 start, int64 end, const Chunk& chunk) {
|
||||
void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) {
|
||||
node_storage_.emplace_back(
|
||||
BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr});
|
||||
|
||||
@ -620,8 +625,7 @@ void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add(
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Chunk>
|
||||
GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::ChunksOverlappingInTime(
|
||||
std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
|
||||
int64 start, int64 end) const {
|
||||
std::vector<Chunk> result;
|
||||
if (node_storage_.empty()) {
|
||||
|
@ -57,6 +57,8 @@ class HeapSimulator {
|
||||
int64 size;
|
||||
|
||||
// Returns offset + size, i.e. the first position past the end of the chunk.
int64 chunk_end() const { return offset + size; }
|
||||
|
||||
// Returns true if this chunk and other_chunk overlap, i.e. each chunk's
// offset is strictly less than the other's chunk_end(). CHECK-fails when
// either chunk has a size of zero.
bool OverlapsWith(Chunk other_chunk) const;
|
||||
};
|
||||
|
||||
// Result represents the result of the heap simulation.
|
||||
@ -284,6 +286,39 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
|
||||
int64 max_heap_size_ = 0;
|
||||
};
|
||||
|
||||
// Node in BufferIntervalTree that stores the alloc and free times of a buffer,
|
||||
// and the chunk assigned to it.
|
||||
struct BufferIntervalTreeNode {
|
||||
// Alloc time.
|
||||
int64 start;
|
||||
// Free time.
|
||||
int64 end;
|
||||
// Maximum free time of all nodes in the subtree where this node is the root.
|
||||
int64 subtree_end;
|
||||
// Allocated chunk for the buffer.
|
||||
HeapSimulator::Chunk chunk;
|
||||
// Left child.
|
||||
BufferIntervalTreeNode* left;
|
||||
// Right child.
|
||||
BufferIntervalTreeNode* right;
|
||||
};
|
||||
|
||||
// An interval tree that can query buffers overlapping in time.
|
||||
class BufferIntervalTree {
|
||||
public:
|
||||
using Chunk = HeapSimulator::Chunk;
|
||||
// Adds a buffer to the interval tree, with the time interval and allocated
|
||||
// chunk specified.
|
||||
void Add(int64 start, int64 end, const Chunk& chunk);
|
||||
|
||||
// Returns vector of allocated chunks that overlap with the given time
|
||||
// interval.
|
||||
std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const;
|
||||
|
||||
private:
|
||||
std::list<BufferIntervalTreeNode> node_storage_;
|
||||
};
|
||||
|
||||
// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers,
|
||||
// then allocates them in decreasing spatial or temporal size regardless of the
|
||||
// alloc/free time. It internally tracks the allocated buffers and their live
|
||||
@ -334,39 +369,6 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
|
||||
static BufferIntervalCompare GetSpatialBufferIntervalCompare();
|
||||
|
||||
protected:
|
||||
// NOTE(review): the enclosing hunk header (-334,39 +369,6) drops 33 lines from
// GlobalDecreasingSizeBestFitHeap, so this nested struct appears to be the
// pre-change copy that this commit removes — confirm against the repository.
// Node in BufferIntervalTree that stores the alloc and free times of a
|
||||
// buffer, and the chunk assigned to it.
|
||||
struct BufferIntervalTreeNode {
|
||||
// Alloc time.
|
||||
int64 start;
|
||||
// Free time.
|
||||
int64 end;
|
||||
// Maximum free time of all nodes in the subtree where this node is the
|
||||
// root.
|
||||
int64 subtree_end;
|
||||
// Allocated chunk for the buffer.
|
||||
HeapSimulator::Chunk chunk;
|
||||
// Left child.
|
||||
BufferIntervalTreeNode* left;
|
||||
// Right child.
|
||||
BufferIntervalTreeNode* right;
|
||||
};
|
||||
|
||||
// NOTE(review): the enclosing hunk header (-334,39 +369,6) drops 33 lines from
// GlobalDecreasingSizeBestFitHeap, so this nested class appears to be the
// pre-change copy that this commit removes — confirm against the repository.
// An interval tree that can query buffers overlapping in time.
|
||||
class BufferIntervalTree {
|
||||
public:
|
||||
// Adds a buffer to the interval tree, with the time interval and allocated
|
||||
// chunk specified.
|
||||
void Add(int64 start, int64 end, const Chunk& chunk);
|
||||
|
||||
// Returns vector of allocated chunks that overlap with the given time
|
||||
// interval.
|
||||
std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const;
|
||||
|
||||
private:
|
||||
std::list<BufferIntervalTreeNode> node_storage_;
|
||||
};
|
||||
|
||||
// The candidate contains a chunk and the resultant heap size if this
|
||||
// chunk is to be committed.
|
||||
struct ChunkCandidate {
|
||||
|
@ -1095,6 +1095,10 @@ MemorySpaceAssignment::Run(HloModule* module, const Options& options) {
|
||||
VLOG(1) << "Maximum number of outstanding async copies: "
|
||||
<< CountMaximumOutstandingAsyncCopies(*module);
|
||||
|
||||
if (options.verify || VLOG_IS_ON(1)) {
|
||||
TF_RETURN_IF_ERROR(memory_space_assignment.Verify());
|
||||
}
|
||||
|
||||
return std::move(memory_space_assignment.preset_assignments_);
|
||||
}
|
||||
|
||||
@ -1509,4 +1513,60 @@ Status MemorySpaceAssignment::FixSchedule() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Checks the preset assignments for buffers that overlap in both time and
// space. Each buffer reached through an assigned position is examined once;
// its live interval runs from the earliest start to the latest end over all of
// its values. Returns an InternalError for the first buffer whose chunk
// overlaps the chunk of a previously examined buffer that is live at the same
// time, and OK when no such pair exists. Errors from the alias analysis or the
// live range computation are returned as-is.
Status MemorySpaceAssignment::Verify() const {
  VLOG(3) << "Verifying:";
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module_));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(module_->schedule(), *alias_analysis,
                                        module_->entry_computation()));

  // Chunks of the buffers examined so far, keyed by their live interval.
  BufferIntervalTree interval_tree;
  // Ids of the buffers examined so far.
  absl::flat_hash_set<int64> seen_buffers;

  for (const auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    const Chunk& chunk = position_and_chunk.second;
    const HloBuffer& buffer =
        alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
    // insert() reports whether the id was newly added; a buffer that was
    // already recorded is not examined again.
    if (!seen_buffers.insert(buffer.id()).second) {
      continue;
    }

    // Widen [start_time, end_time] to cover the live range of every value.
    int64 start_time = INT64_MAX;
    int64 end_time = -1;
    for (const HloValue* value : buffer.values()) {
      const HloLiveRange::TimeBound& bound =
          hlo_live_range->buffer_live_ranges().at(value);
      start_time = std::min(start_time, bound.start);
      end_time = std::max(end_time, bound.end);
    }
    CHECK_GE(start_time, 0);
    CHECK_GT(end_time, 0);

    // Look up the chunks overlapping in time and test whether any of them
    // overlaps in space as well.
    // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
    // really should check against end_time (inclusive) for cases where the
    // operand can't share buffer with user (see
    // HloDataflowAnalysis::CanShareOperandBufferWithUser).
    const int64 last_time = end_time - 1;
    for (const Chunk& other_chunk :
         interval_tree.ChunksOverlappingInTime(start_time, last_time)) {
      if (!chunk.OverlapsWith(other_chunk)) {
        continue;
      }
      return InternalError(
          ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk"
           " off: %d size: %d"),
          buffer.ToString(), start_time, end_time, chunk.offset, chunk.size,
          other_chunk.offset, other_chunk.size);
    }

    interval_tree.Add(start_time, last_time, chunk);
    VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", "
            << end_time << ") off: " << chunk.offset
            << ", size: " << chunk.size;
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -299,6 +299,10 @@ class MemorySpaceAssignment {
|
||||
// If true, tries allocating buffers across (e.g., before and inside a while
|
||||
// loop body) sequential calls (kWhile, kCall, and kConditional).
|
||||
bool allocate_across_sequential_calls = false;
|
||||
|
||||
// If true, verifies the memory space assignment against overlapping
|
||||
// buffers.
|
||||
bool verify = false;
|
||||
};
|
||||
|
||||
// This class represents an allocation that might either be in the default or
|
||||
@ -471,6 +475,9 @@ class MemorySpaceAssignment {
|
||||
static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis);
|
||||
|
||||
// Verifies that the memory space assignment is free of overlapping buffers.
// Returns an InternalError describing the first chunk found to overlap, in
// both time and space, with the chunk of another buffer; returns OK if no
// overlap is found.
|
||||
Status Verify() const;
|
||||
|
||||
private:
|
||||
MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space,
|
||||
const HloLiveRange& hlo_live_range)
|
||||
|
@ -107,6 +107,7 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
||||
options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
|
||||
options.max_outstanding_async_copies = max_outstanding_async_copies;
|
||||
options.allocate_across_sequential_calls = GetParam();
|
||||
options.verify = true;
|
||||
std::unique_ptr<PresetAssignments> preset_assignments =
|
||||
MemorySpaceAssignment::Run(module, options).ValueOrDie();
|
||||
CheckPresetAssignments(preset_assignments.get());
|
||||
|
Loading…
Reference in New Issue
Block a user