[XLA] Add a verifier to memory space assignment to check against overlapping buffers
PiperOrigin-RevId: 287573437 Change-Id: Ic28ee1bd4ff191c2194fb001880530eeeb4acea2
This commit is contained in:
parent
80c685a55f
commit
aa38702d75
@ -31,6 +31,12 @@ namespace xla {
|
|||||||
using absl::flat_hash_map;
|
using absl::flat_hash_map;
|
||||||
using absl::flat_hash_set;
|
using absl::flat_hash_set;
|
||||||
|
|
||||||
|
bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
|
||||||
|
CHECK_NE(size, 0);
|
||||||
|
CHECK_NE(other_chunk.size, 0);
|
||||||
|
return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
|
||||||
|
}
|
||||||
|
|
||||||
/*static*/
|
/*static*/
|
||||||
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
|
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
|
||||||
const HloSchedule& schedule,
|
const HloSchedule& schedule,
|
||||||
@ -591,8 +597,7 @@ void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) {
|
|||||||
|
|
||||||
using Chunk = HeapSimulator::Chunk;
|
using Chunk = HeapSimulator::Chunk;
|
||||||
|
|
||||||
void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add(
|
void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) {
|
||||||
int64 start, int64 end, const Chunk& chunk) {
|
|
||||||
node_storage_.emplace_back(
|
node_storage_.emplace_back(
|
||||||
BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr});
|
BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr});
|
||||||
|
|
||||||
@ -620,8 +625,7 @@ void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Chunk>
|
std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
|
||||||
GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::ChunksOverlappingInTime(
|
|
||||||
int64 start, int64 end) const {
|
int64 start, int64 end) const {
|
||||||
std::vector<Chunk> result;
|
std::vector<Chunk> result;
|
||||||
if (node_storage_.empty()) {
|
if (node_storage_.empty()) {
|
||||||
|
@ -57,6 +57,8 @@ class HeapSimulator {
|
|||||||
int64 size;
|
int64 size;
|
||||||
|
|
||||||
int64 chunk_end() const { return offset + size; }
|
int64 chunk_end() const { return offset + size; }
|
||||||
|
|
||||||
|
bool OverlapsWith(Chunk other_chunk) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Result represents the result of the heap simulation.
|
// Result represents the result of the heap simulation.
|
||||||
@ -284,6 +286,39 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
|
|||||||
int64 max_heap_size_ = 0;
|
int64 max_heap_size_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Node in BufferIntervalTree that stores the alloc and free times of a buffer,
|
||||||
|
// and the chunk assigned to it.
|
||||||
|
struct BufferIntervalTreeNode {
|
||||||
|
// Alloc time.
|
||||||
|
int64 start;
|
||||||
|
// Free time.
|
||||||
|
int64 end;
|
||||||
|
// Maximum free time of all nodes in the subtree where this node is the root.
|
||||||
|
int64 subtree_end;
|
||||||
|
// Allocated chunk for the buffer.
|
||||||
|
HeapSimulator::Chunk chunk;
|
||||||
|
// Left child.
|
||||||
|
BufferIntervalTreeNode* left;
|
||||||
|
// Right child.
|
||||||
|
BufferIntervalTreeNode* right;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An interval tree that can query buffers overlapping in time.
|
||||||
|
class BufferIntervalTree {
|
||||||
|
public:
|
||||||
|
using Chunk = HeapSimulator::Chunk;
|
||||||
|
// Adds a buffer to the interval tree, with the time interval and allocated
|
||||||
|
// chunk specified.
|
||||||
|
void Add(int64 start, int64 end, const Chunk& chunk);
|
||||||
|
|
||||||
|
// Returns vector of allocated chunks that overlap with the given time
|
||||||
|
// interval.
|
||||||
|
std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::list<BufferIntervalTreeNode> node_storage_;
|
||||||
|
};
|
||||||
|
|
||||||
// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers,
|
// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers,
|
||||||
// then allocates them in decreasing spatial or temporal size regardless of the
|
// then allocates them in decreasing spatial or temporal size regardless of the
|
||||||
// alloc/free time. It internally tracks the allocated buffers and their live
|
// alloc/free time. It internally tracks the allocated buffers and their live
|
||||||
@ -334,39 +369,6 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
|
|||||||
static BufferIntervalCompare GetSpatialBufferIntervalCompare();
|
static BufferIntervalCompare GetSpatialBufferIntervalCompare();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Node in BufferIntervalTree that stores the alloc and free times of a
|
|
||||||
// buffer, and the chunk assigned to it.
|
|
||||||
struct BufferIntervalTreeNode {
|
|
||||||
// Alloc time.
|
|
||||||
int64 start;
|
|
||||||
// Free time.
|
|
||||||
int64 end;
|
|
||||||
// Maximum free time of all nodes in the subtree where this node is the
|
|
||||||
// root.
|
|
||||||
int64 subtree_end;
|
|
||||||
// Allocated chunk for the buffer.
|
|
||||||
HeapSimulator::Chunk chunk;
|
|
||||||
// Left child.
|
|
||||||
BufferIntervalTreeNode* left;
|
|
||||||
// Right child.
|
|
||||||
BufferIntervalTreeNode* right;
|
|
||||||
};
|
|
||||||
|
|
||||||
// An interval tree that can query buffers overlapping in time.
|
|
||||||
class BufferIntervalTree {
|
|
||||||
public:
|
|
||||||
// Adds a buffer to the interval tree, with the time interval and allocated
|
|
||||||
// chunk specified.
|
|
||||||
void Add(int64 start, int64 end, const Chunk& chunk);
|
|
||||||
|
|
||||||
// Returns vector of allocated chunks that overlap with the given time
|
|
||||||
// interval.
|
|
||||||
std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::list<BufferIntervalTreeNode> node_storage_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// The candidate contains a chunk and the resultant heap size if this
|
// The candidate contains a chunk and the resultant heap size if this
|
||||||
// chunk is to be committed.
|
// chunk is to be committed.
|
||||||
struct ChunkCandidate {
|
struct ChunkCandidate {
|
||||||
|
@ -1095,6 +1095,10 @@ MemorySpaceAssignment::Run(HloModule* module, const Options& options) {
|
|||||||
VLOG(1) << "Maximum number of outstanding async copies: "
|
VLOG(1) << "Maximum number of outstanding async copies: "
|
||||||
<< CountMaximumOutstandingAsyncCopies(*module);
|
<< CountMaximumOutstandingAsyncCopies(*module);
|
||||||
|
|
||||||
|
if (options.verify || VLOG_IS_ON(1)) {
|
||||||
|
TF_RETURN_IF_ERROR(memory_space_assignment.Verify());
|
||||||
|
}
|
||||||
|
|
||||||
return std::move(memory_space_assignment.preset_assignments_);
|
return std::move(memory_space_assignment.preset_assignments_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1509,4 +1513,60 @@ Status MemorySpaceAssignment::FixSchedule() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status MemorySpaceAssignment::Verify() const {
|
||||||
|
VLOG(3) << "Verifying:";
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||||
|
HloAliasAnalysis::Run(module_));
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
|
||||||
|
HloLiveRange::Run(module_->schedule(), *alias_analysis,
|
||||||
|
module_->entry_computation()));
|
||||||
|
|
||||||
|
BufferIntervalTree interval_tree;
|
||||||
|
absl::flat_hash_set<int64> seen_buffers;
|
||||||
|
|
||||||
|
for (const auto& position_and_chunk : preset_assignments_->chunks()) {
|
||||||
|
const HloPosition& position = position_and_chunk.first;
|
||||||
|
const Chunk& chunk = position_and_chunk.second;
|
||||||
|
const HloBuffer& buffer =
|
||||||
|
alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
|
||||||
|
if (seen_buffers.contains(buffer.id())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen_buffers.insert(buffer.id());
|
||||||
|
|
||||||
|
int64 start_time = INT64_MAX;
|
||||||
|
int64 end_time = -1;
|
||||||
|
for (const HloValue* value : buffer.values()) {
|
||||||
|
const HloLiveRange::TimeBound& time_bound =
|
||||||
|
hlo_live_range->buffer_live_ranges().at(value);
|
||||||
|
start_time = std::min(start_time, time_bound.start);
|
||||||
|
end_time = std::max(end_time, time_bound.end);
|
||||||
|
}
|
||||||
|
CHECK_GE(start_time, 0);
|
||||||
|
CHECK_GT(end_time, 0);
|
||||||
|
// Get the chunks overlapping in time and search if they overlap in space as
|
||||||
|
// well.
|
||||||
|
// TODO(berkin): For now checking against end_time - 1 (exclusive), but we
|
||||||
|
// really should check against end_time (inclusive) for cases where the
|
||||||
|
// operand can't share buffer with user (see
|
||||||
|
// HloDataflowAnalysis::CanShareOperandBufferWithUser).
|
||||||
|
for (const Chunk& overlapping_chunk :
|
||||||
|
interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
|
||||||
|
if (chunk.OverlapsWith(overlapping_chunk)) {
|
||||||
|
return InternalError(
|
||||||
|
("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk"
|
||||||
|
" off: %d size: %d"),
|
||||||
|
buffer.ToString(), start_time, end_time, chunk.offset, chunk.size,
|
||||||
|
overlapping_chunk.offset, overlapping_chunk.size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
interval_tree.Add(start_time, end_time - 1, chunk);
|
||||||
|
VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", "
|
||||||
|
<< end_time << ") off: " << position_and_chunk.second.offset
|
||||||
|
<< ", size: " << position_and_chunk.second.size;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -299,6 +299,10 @@ class MemorySpaceAssignment {
|
|||||||
// If true, tries allocating buffers across (e.g., before and inside a while
|
// If true, tries allocating buffers across (e.g., before and inside a while
|
||||||
// loop body) sequential calls (kWhile, kCall, and kConditional).
|
// loop body) sequential calls (kWhile, kCall, and kConditional).
|
||||||
bool allocate_across_sequential_calls = false;
|
bool allocate_across_sequential_calls = false;
|
||||||
|
|
||||||
|
// If true, verifies the memory space assignment against overlapping
|
||||||
|
// buffers.
|
||||||
|
bool verify = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// This class represents an allocation that might either be in the default or
|
// This class represents an allocation that might either be in the default or
|
||||||
@ -471,6 +475,9 @@ class MemorySpaceAssignment {
|
|||||||
static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
|
static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
|
||||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis);
|
const MemorySpaceAssignmentCostAnalysis& cost_analysis);
|
||||||
|
|
||||||
|
// Verify that the memory space assignment is free of overlapping buffers.
|
||||||
|
Status Verify() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space,
|
MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space,
|
||||||
const HloLiveRange& hlo_live_range)
|
const HloLiveRange& hlo_live_range)
|
||||||
|
@ -107,6 +107,7 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
|||||||
options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
|
options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
|
||||||
options.max_outstanding_async_copies = max_outstanding_async_copies;
|
options.max_outstanding_async_copies = max_outstanding_async_copies;
|
||||||
options.allocate_across_sequential_calls = GetParam();
|
options.allocate_across_sequential_calls = GetParam();
|
||||||
|
options.verify = true;
|
||||||
std::unique_ptr<PresetAssignments> preset_assignments =
|
std::unique_ptr<PresetAssignments> preset_assignments =
|
||||||
MemorySpaceAssignment::Run(module, options).ValueOrDie();
|
MemorySpaceAssignment::Run(module, options).ValueOrDie();
|
||||||
CheckPresetAssignments(preset_assignments.get());
|
CheckPresetAssignments(preset_assignments.get());
|
||||||
|
Loading…
x
Reference in New Issue
Block a user