[XLA] Add a verifier to memory space assignment to check against overlapping buffers

PiperOrigin-RevId: 287573437
Change-Id: Ic28ee1bd4ff191c2194fb001880530eeeb4acea2
This commit is contained in:
Berkin Ilbeyi 2019-12-30 11:15:34 -08:00 committed by Jacques Pienaar
parent 80c685a55f
commit aa38702d75
5 changed files with 111 additions and 37 deletions

View File

@ -31,6 +31,12 @@ namespace xla {
using absl::flat_hash_map;
using absl::flat_hash_set;
// Reports whether this chunk and `other_chunk` share any offset, treating each
// chunk as the half-open range [offset, chunk_end()). Two chunks that merely
// touch (one ends exactly where the other begins) do not overlap.
// Both chunks must have a non-zero size; this is enforced with CHECK_NE.
bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  const bool begins_before_other_ends = offset < other_chunk.chunk_end();
  const bool other_begins_before_this_ends = other_chunk.offset < chunk_end();
  return begins_before_other_ends && other_begins_before_this_ends;
}
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
const HloSchedule& schedule,
@ -591,8 +597,7 @@ void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) {
using Chunk = HeapSimulator::Chunk;
void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add(
int64 start, int64 end, const Chunk& chunk) {
void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) {
node_storage_.emplace_back(
BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr});
@ -620,8 +625,7 @@ void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add(
}
}
std::vector<Chunk>
GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::ChunksOverlappingInTime(
std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
int64 start, int64 end) const {
std::vector<Chunk> result;
if (node_storage_.empty()) {

View File

@ -57,6 +57,8 @@ class HeapSimulator {
int64 size;
// Returns the offset one past the last position covered by this chunk, i.e.
// offset + size.
int64 chunk_end() const { return offset + size; }
// Returns true if the range [offset, chunk_end()) of this chunk intersects
// that of other_chunk. CHECK-fails if either chunk has size 0.
bool OverlapsWith(Chunk other_chunk) const;
};
// Result represents the result of the heap simulation.
@ -284,6 +286,39 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
int64 max_heap_size_ = 0;
};
// Node in BufferIntervalTree that stores the alloc and free times of a buffer,
// and the chunk assigned to it.
//
// BufferIntervalTree::Add builds nodes with positional brace initialization
// ({start, end, subtree_end, chunk, left, right}), so the order of the fields
// below must not change.
struct BufferIntervalTreeNode {
// Alloc time.
int64 start;
// Free time.
int64 end;
// Maximum free time of all nodes in the subtree where this node is the root.
// A newly created node starts with subtree_end equal to its own end.
int64 subtree_end;
// Allocated chunk for the buffer.
HeapSimulator::Chunk chunk;
// Left child. Initialized to nullptr when the node is created.
// NOTE(review): presumably a non-owning pointer to another node held by the
// tree's node storage; confirm against BufferIntervalTree::Add.
BufferIntervalTreeNode* left;
// Right child. Initialized to nullptr when the node is created.
BufferIntervalTreeNode* right;
};
// An interval tree that can query buffers overlapping in time. Each entry
// pairs a time interval with the chunk that was allocated for that interval.
class BufferIntervalTree {
public:
using Chunk = HeapSimulator::Chunk;
// Adds a buffer to the interval tree, with the time interval and allocated
// chunk specified.
// NOTE(review): whether `end` is inclusive or exclusive is not stated here;
// callers should confirm against the implementation.
void Add(int64 start, int64 end, const Chunk& chunk);
// Returns vector of allocated chunks that overlap with the given time
// interval. The chunks are returned by value.
std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const;
private:
// Backing storage for every node added to the tree.
// NOTE(review): std::list keeps element addresses stable as nodes are
// appended, which is presumably why it is used together with the raw child
// pointers in BufferIntervalTreeNode; confirm against Add.
std::list<BufferIntervalTreeNode> node_storage_;
};
// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers,
// then allocates them in decreasing spatial or temporal size regardless of the
// alloc/free time. It internally tracks the allocated buffers and their live
@ -334,39 +369,6 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
static BufferIntervalCompare GetSpatialBufferIntervalCompare();
protected:
// Node in BufferIntervalTree that stores the alloc and free times of a
// buffer, and the chunk assigned to it.
//
// Nodes are built with positional brace initialization
// ({start, end, subtree_end, chunk, left, right}), so the order of the fields
// below must not change.
struct BufferIntervalTreeNode {
// Alloc time.
int64 start;
// Free time.
int64 end;
// Maximum free time of all nodes in the subtree where this node is the
// root. A newly created node starts with subtree_end equal to its own end.
int64 subtree_end;
// Allocated chunk for the buffer.
HeapSimulator::Chunk chunk;
// Left child. Initialized to nullptr when the node is created.
BufferIntervalTreeNode* left;
// Right child. Initialized to nullptr when the node is created.
BufferIntervalTreeNode* right;
};
// An interval tree that can query buffers overlapping in time. Each entry
// pairs a time interval with the chunk that was allocated for that interval.
class BufferIntervalTree {
public:
// Adds a buffer to the interval tree, with the time interval and allocated
// chunk specified.
void Add(int64 start, int64 end, const Chunk& chunk);
// Returns vector of allocated chunks that overlap with the given time
// interval. The chunks are returned by value.
std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const;
private:
// Backing storage for every node added to the tree.
std::list<BufferIntervalTreeNode> node_storage_;
};
// The candidate contains a chunk and the resultant heap size if this
// chunk is to be committed.
struct ChunkCandidate {

View File

@ -1095,6 +1095,10 @@ MemorySpaceAssignment::Run(HloModule* module, const Options& options) {
VLOG(1) << "Maximum number of outstanding async copies: "
<< CountMaximumOutstandingAsyncCopies(*module);
if (options.verify || VLOG_IS_ON(1)) {
TF_RETURN_IF_ERROR(memory_space_assignment.Verify());
}
return std::move(memory_space_assignment.preset_assignments_);
}
@ -1509,4 +1513,60 @@ Status MemorySpaceAssignment::FixSchedule() {
return Status::OK();
}
// Checks the chunks recorded in preset_assignments_ for conflicts: two buffers
// whose live intervals overlap in time must not have been given chunks that
// overlap in space. Returns an internal error describing the first conflict
// found, or OK when there is none. Errors from the alias analysis and the live
// range computation are propagated to the caller.
Status MemorySpaceAssignment::Verify() const {
  VLOG(3) << "Verifying:";
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> aliases,
                      HloAliasAnalysis::Run(module_));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> live_ranges,
                      HloLiveRange::Run(module_->schedule(), *aliases,
                                        module_->entry_computation()));

  // Chunks of the buffers verified so far, keyed by their live interval.
  BufferIntervalTree verified_intervals;
  // Several positions may resolve to the same buffer; check each buffer once.
  absl::flat_hash_set<int64> visited_buffer_ids;

  for (const auto& entry : preset_assignments_->chunks()) {
    const HloPosition& position = entry.first;
    const Chunk& chunk = entry.second;
    const HloBuffer& buffer =
        aliases->GetUniqueBufferAt(position.instruction, position.index);
    // insert() reports through .second whether the id was newly added.
    if (!visited_buffer_ids.insert(buffer.id()).second) {
      continue;
    }

    // The buffer's live interval spans the live ranges of all of its values.
    int64 start_time = INT64_MAX;
    int64 end_time = -1;
    for (const HloValue* value : buffer.values()) {
      const HloLiveRange::TimeBound& bound =
          live_ranges->buffer_live_ranges().at(value);
      if (bound.start < start_time) {
        start_time = bound.start;
      }
      if (end_time < bound.end) {
        end_time = bound.end;
      }
    }
    CHECK_GE(start_time, 0);
    CHECK_GT(end_time, 0);

    // Fetch the chunks that are live at the same time as this buffer, then
    // look for one that also collides with it in space.
    // TODO(berkin): The query currently uses end_time - 1 (exclusive); it
    // really should use end_time (inclusive) to cover the cases where an
    // operand cannot share its buffer with the user (see
    // HloDataflowAnalysis::CanShareOperandBufferWithUser).
    const std::vector<Chunk> concurrent_chunks =
        verified_intervals.ChunksOverlappingInTime(start_time, end_time - 1);
    for (const Chunk& concurrent_chunk : concurrent_chunks) {
      if (!chunk.OverlapsWith(concurrent_chunk)) {
        continue;
      }
      return InternalError(
          ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk"
           " off: %d size: %d"),
          buffer.ToString(), start_time, end_time, chunk.offset, chunk.size,
          concurrent_chunk.offset, concurrent_chunk.size);
    }

    verified_intervals.Add(start_time, end_time - 1, chunk);
    VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", "
            << end_time << ") off: " << chunk.offset
            << ", size: " << chunk.size;
  }

  return Status::OK();
}
} // namespace xla

View File

@ -299,6 +299,10 @@ class MemorySpaceAssignment {
// If true, tries allocating buffers across (e.g., before and inside a while
// loop body) sequential calls (kWhile, kCall, and kConditional).
bool allocate_across_sequential_calls = false;
// If true, verifies the memory space assignment against overlapping
// buffers.
bool verify = false;
};
// This class represents an allocation that might either be in the default or
@ -471,6 +475,9 @@ class MemorySpaceAssignment {
static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
const MemorySpaceAssignmentCostAnalysis& cost_analysis);
// Verify that the memory space assignment is free of overlapping buffers.
// Returns a non-OK status if two buffers that are live at the same time were
// assigned chunks that overlap.
Status Verify() const;
private:
MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space,
const HloLiveRange& hlo_live_range)

View File

@ -107,6 +107,7 @@ class MemorySpaceAssignmentTest : public HloTestBase,
options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
options.max_outstanding_async_copies = max_outstanding_async_copies;
options.allocate_across_sequential_calls = GetParam();
options.verify = true;
std::unique_ptr<PresetAssignments> preset_assignments =
MemorySpaceAssignment::Run(module, options).ValueOrDie();
CheckPresetAssignments(preset_assignments.get());