[XLA] Add a verifier to memory space assignment to check against overlapping buffers
PiperOrigin-RevId: 287573437 Change-Id: Ic28ee1bd4ff191c2194fb001880530eeeb4acea2
This commit is contained in:
		
							parent
							
								
									80c685a55f
								
							
						
					
					
						commit
						aa38702d75
					
				@ -31,6 +31,12 @@ namespace xla {
 | 
			
		||||
using absl::flat_hash_map;
 | 
			
		||||
using absl::flat_hash_set;
 | 
			
		||||
 | 
			
		||||
bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
 | 
			
		||||
  CHECK_NE(size, 0);
 | 
			
		||||
  CHECK_NE(other_chunk.size, 0);
 | 
			
		||||
  return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*static*/
 | 
			
		||||
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
 | 
			
		||||
    const HloSchedule& schedule,
 | 
			
		||||
@ -591,8 +597,7 @@ void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) {
 | 
			
		||||
 | 
			
		||||
using Chunk = HeapSimulator::Chunk;
 | 
			
		||||
 | 
			
		||||
void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add(
 | 
			
		||||
    int64 start, int64 end, const Chunk& chunk) {
 | 
			
		||||
void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) {
 | 
			
		||||
  node_storage_.emplace_back(
 | 
			
		||||
      BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr});
 | 
			
		||||
 | 
			
		||||
@ -620,8 +625,7 @@ void GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::Add(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::vector<Chunk>
 | 
			
		||||
GlobalDecreasingSizeBestFitHeap::BufferIntervalTree::ChunksOverlappingInTime(
 | 
			
		||||
std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
 | 
			
		||||
    int64 start, int64 end) const {
 | 
			
		||||
  std::vector<Chunk> result;
 | 
			
		||||
  if (node_storage_.empty()) {
 | 
			
		||||
 | 
			
		||||
@ -57,6 +57,8 @@ class HeapSimulator {
 | 
			
		||||
    int64 size;
 | 
			
		||||
 | 
			
		||||
    int64 chunk_end() const { return offset + size; }
 | 
			
		||||
 | 
			
		||||
    bool OverlapsWith(Chunk other_chunk) const;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // Result represents the result of the heap simulation.
 | 
			
		||||
@ -284,6 +286,39 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
 | 
			
		||||
  int64 max_heap_size_ = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Node in BufferIntervalTree that stores the alloc and free times of a buffer,
 | 
			
		||||
// and the chunk assigned to it.
 | 
			
		||||
struct BufferIntervalTreeNode {
 | 
			
		||||
  // Alloc time.
 | 
			
		||||
  int64 start;
 | 
			
		||||
  // Free time.
 | 
			
		||||
  int64 end;
 | 
			
		||||
  // Maximum free time of all nodes in the subtree where this node is the root.
 | 
			
		||||
  int64 subtree_end;
 | 
			
		||||
  // Allocated chunk for the buffer.
 | 
			
		||||
  HeapSimulator::Chunk chunk;
 | 
			
		||||
  // Left child.
 | 
			
		||||
  BufferIntervalTreeNode* left;
 | 
			
		||||
  // Right child.
 | 
			
		||||
  BufferIntervalTreeNode* right;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// An interval tree that can query buffers overlapping in time.
 | 
			
		||||
class BufferIntervalTree {
 | 
			
		||||
 public:
 | 
			
		||||
  using Chunk = HeapSimulator::Chunk;
 | 
			
		||||
  // Adds a buffer to the interval tree, with the time interval and allocated
 | 
			
		||||
  // chunk specified.
 | 
			
		||||
  void Add(int64 start, int64 end, const Chunk& chunk);
 | 
			
		||||
 | 
			
		||||
  // Returns vector of allocated chunks that overlap with the given time
 | 
			
		||||
  // interval.
 | 
			
		||||
  std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  std::list<BufferIntervalTreeNode> node_storage_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers,
 | 
			
		||||
// then allocates them in decreasing spatial or temporal size regardless of the
 | 
			
		||||
// alloc/free time. It internally tracks the allocated buffers and their live
 | 
			
		||||
@ -334,39 +369,6 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
 | 
			
		||||
  static BufferIntervalCompare GetSpatialBufferIntervalCompare();
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  // Node in BufferIntervalTree that stores the alloc and free times of a
 | 
			
		||||
  // buffer, and the chunk assigned to it.
 | 
			
		||||
  struct BufferIntervalTreeNode {
 | 
			
		||||
    // Alloc time.
 | 
			
		||||
    int64 start;
 | 
			
		||||
    // Free time.
 | 
			
		||||
    int64 end;
 | 
			
		||||
    // Maximum free time of all nodes in the subtree where this node is the
 | 
			
		||||
    // root.
 | 
			
		||||
    int64 subtree_end;
 | 
			
		||||
    // Allocated chunk for the buffer.
 | 
			
		||||
    HeapSimulator::Chunk chunk;
 | 
			
		||||
    // Left child.
 | 
			
		||||
    BufferIntervalTreeNode* left;
 | 
			
		||||
    // Right child.
 | 
			
		||||
    BufferIntervalTreeNode* right;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // An interval tree that can query buffers overlapping in time.
 | 
			
		||||
  class BufferIntervalTree {
 | 
			
		||||
   public:
 | 
			
		||||
    // Adds a buffer to the interval tree, with the time interval and allocated
 | 
			
		||||
    // chunk specified.
 | 
			
		||||
    void Add(int64 start, int64 end, const Chunk& chunk);
 | 
			
		||||
 | 
			
		||||
    // Returns vector of allocated chunks that overlap with the given time
 | 
			
		||||
    // interval.
 | 
			
		||||
    std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) const;
 | 
			
		||||
 | 
			
		||||
   private:
 | 
			
		||||
    std::list<BufferIntervalTreeNode> node_storage_;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // The candidate contains a chunk and the resultant heap size if this
 | 
			
		||||
  // chunk is to be committed.
 | 
			
		||||
  struct ChunkCandidate {
 | 
			
		||||
 | 
			
		||||
@ -1095,6 +1095,10 @@ MemorySpaceAssignment::Run(HloModule* module, const Options& options) {
 | 
			
		||||
  VLOG(1) << "Maximum number of outstanding async copies: "
 | 
			
		||||
          << CountMaximumOutstandingAsyncCopies(*module);
 | 
			
		||||
 | 
			
		||||
  if (options.verify || VLOG_IS_ON(1)) {
 | 
			
		||||
    TF_RETURN_IF_ERROR(memory_space_assignment.Verify());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return std::move(memory_space_assignment.preset_assignments_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1509,4 +1513,60 @@ Status MemorySpaceAssignment::FixSchedule() {
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status MemorySpaceAssignment::Verify() const {
 | 
			
		||||
  VLOG(3) << "Verifying:";
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
 | 
			
		||||
                      HloAliasAnalysis::Run(module_));
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
 | 
			
		||||
                      HloLiveRange::Run(module_->schedule(), *alias_analysis,
 | 
			
		||||
                                        module_->entry_computation()));
 | 
			
		||||
 | 
			
		||||
  BufferIntervalTree interval_tree;
 | 
			
		||||
  absl::flat_hash_set<int64> seen_buffers;
 | 
			
		||||
 | 
			
		||||
  for (const auto& position_and_chunk : preset_assignments_->chunks()) {
 | 
			
		||||
    const HloPosition& position = position_and_chunk.first;
 | 
			
		||||
    const Chunk& chunk = position_and_chunk.second;
 | 
			
		||||
    const HloBuffer& buffer =
 | 
			
		||||
        alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
 | 
			
		||||
    if (seen_buffers.contains(buffer.id())) {
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
    seen_buffers.insert(buffer.id());
 | 
			
		||||
 | 
			
		||||
    int64 start_time = INT64_MAX;
 | 
			
		||||
    int64 end_time = -1;
 | 
			
		||||
    for (const HloValue* value : buffer.values()) {
 | 
			
		||||
      const HloLiveRange::TimeBound& time_bound =
 | 
			
		||||
          hlo_live_range->buffer_live_ranges().at(value);
 | 
			
		||||
      start_time = std::min(start_time, time_bound.start);
 | 
			
		||||
      end_time = std::max(end_time, time_bound.end);
 | 
			
		||||
    }
 | 
			
		||||
    CHECK_GE(start_time, 0);
 | 
			
		||||
    CHECK_GT(end_time, 0);
 | 
			
		||||
    // Get the chunks overlapping in time and search if they overlap in space as
 | 
			
		||||
    // well.
 | 
			
		||||
    // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
 | 
			
		||||
    // really should check against end_time (inclusive) for cases where the
 | 
			
		||||
    // operand can't share buffer with user (see
 | 
			
		||||
    // HloDataflowAnalysis::CanShareOperandBufferWithUser).
 | 
			
		||||
    for (const Chunk& overlapping_chunk :
 | 
			
		||||
         interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
 | 
			
		||||
      if (chunk.OverlapsWith(overlapping_chunk)) {
 | 
			
		||||
        return InternalError(
 | 
			
		||||
            ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk"
 | 
			
		||||
             " off: %d size: %d"),
 | 
			
		||||
            buffer.ToString(), start_time, end_time, chunk.offset, chunk.size,
 | 
			
		||||
            overlapping_chunk.offset, overlapping_chunk.size);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    interval_tree.Add(start_time, end_time - 1, chunk);
 | 
			
		||||
    VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", "
 | 
			
		||||
            << end_time << ") off: " << position_and_chunk.second.offset
 | 
			
		||||
            << ", size: " << position_and_chunk.second.size;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
@ -299,6 +299,10 @@ class MemorySpaceAssignment {
 | 
			
		||||
    // If true, tries allocating buffers across (e.g., before and inside a while
 | 
			
		||||
    // loop body) sequential calls (kWhile, kCall, and kConditional).
 | 
			
		||||
    bool allocate_across_sequential_calls = false;
 | 
			
		||||
 | 
			
		||||
    // If true, verifies the memory space assignment against overlapping
 | 
			
		||||
    // buffers.
 | 
			
		||||
    bool verify = false;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // This class represents an allocation that might either be in the default or
 | 
			
		||||
@ -471,6 +475,9 @@ class MemorySpaceAssignment {
 | 
			
		||||
  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
 | 
			
		||||
      const MemorySpaceAssignmentCostAnalysis& cost_analysis);
 | 
			
		||||
 | 
			
		||||
  // Verify that the memory space assignment is free of overlapping buffers.
 | 
			
		||||
  Status Verify() const;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space,
 | 
			
		||||
                        const HloLiveRange& hlo_live_range)
 | 
			
		||||
 | 
			
		||||
@ -107,6 +107,7 @@ class MemorySpaceAssignmentTest : public HloTestBase,
 | 
			
		||||
    options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
 | 
			
		||||
    options.max_outstanding_async_copies = max_outstanding_async_copies;
 | 
			
		||||
    options.allocate_across_sequential_calls = GetParam();
 | 
			
		||||
    options.verify = true;
 | 
			
		||||
    std::unique_ptr<PresetAssignments> preset_assignments =
 | 
			
		||||
        MemorySpaceAssignment::Run(module, options).ValueOrDie();
 | 
			
		||||
    CheckPresetAssignments(preset_assignments.get());
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user