[XLA] Use templates in heap simulator to allow the opaque buffer type to be different from HloValue (NFC)
This CL allows reusing the heap algorithm machinery for opaque types other than HloValue. This is in preparation for using heap algorithms as memory space assignment repackers to reduce fragmentation of the alternate memory.

PiperOrigin-RevId: 326745711
Change-Id: I30845956ee22a1958eb7ea39a9653f1cefa7691b
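The change is mechanical but broad, so a compact sketch may help. The following is illustrative only and not part of the commit; MyBlock is a hypothetical stand-in for any buffer type that supplies the operator< the heap's comparators now rely on:

    // Hypothetical opaque buffer type; only a deterministic operator< is
    // required by the heap's tie-breaking comparators.
    struct MyBlock {
      int64 id;
      bool operator<(const MyBlock& other) const { return id < other.id; }
    };

    void Sketch(const MyBlock* a, const MyBlock* b) {
      GlobalDecreasingSizeBestFitHeap<MyBlock> heap(/*alignment=*/1);
      heap.Alloc(a, /*size=*/16);
      heap.Alloc(b, /*size=*/32);
      heap.Free(a, /*size=*/16);
      heap.Free(b, /*size=*/32);
      HeapSimulator::Result<MyBlock> result = heap.Finish();
      // result.chunk_map maps const MyBlock* to Chunk{offset, size}.
    }

Because the template definitions stay in heap_simulator.cc, this would also need an explicit instantiation there, exactly as the commit adds one for MemorySpaceAssignmentRepacker::AllocationBlock.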
parent 73b40908a4
commit ba58b8cafa
tensorflow/compiler/xla/service
@@ -1431,6 +1431,7 @@ cc_library(
         ":hlo_live_range",
         ":hlo_ordering",
         ":hlo_proto_cc",
+        ":memory_space_assignment_repacking",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
@@ -1424,13 +1424,16 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
   // Returns a heap algorithm that chooses the best result from several
   // algorithms.
   auto get_heap_algorithm = [&](int64 alignment) {
-    auto algorithms =
-        absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
-    algorithms->push_back(absl::make_unique<GlobalDecreasingSizeBestFitHeap>(
-        alignment, GlobalDecreasingSizeBestFitHeap::kSpatial));
-    algorithms->push_back(absl::make_unique<GlobalDecreasingSizeBestFitHeap>(
-        alignment, GlobalDecreasingSizeBestFitHeap::kTemporal));
-    return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
+    auto algorithms = absl::make_unique<
+        std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
+    algorithms->push_back(
+        absl::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
+            alignment, GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
+    algorithms->push_back(
+        absl::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
+            alignment, GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
+    return absl::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
+        std::move(algorithms));
   };

   if (run_whole_module_heap_simulation) {
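For context on the lambda above: ChooseBestHeapAlgorithm fans every Alloc/Free call out to each algorithm in the list, here one spatially sorted and one temporally sorted best-fit heap, and keeps whichever packing yields the smaller heap. A sketch of that selection, consistent with the ChooseBestHeapAlgorithm<BufferType>::Finish() shown later in this diff (the helper name is illustrative):

    // Sketch: pick the result with the smallest heap among all algorithms.
    template <typename T>
    HeapSimulator::Result<T> PickSmallest(
        std::vector<HeapSimulator::Result<T>> results) {
      int64 min_size = INT64_MAX;
      int min_size_index = -1;
      for (int i = 0; i < results.size(); ++i) {
        if (results[i].heap_size < min_size) {
          min_size = results[i].heap_size;
          min_size_index = i;
        }
      }
      return results[min_size_index];
    }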
@@ -1461,7 +1464,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
       options.buffers_to_assign = &single_colored_set.second;

       TF_ASSIGN_OR_RETURN(
-          HeapSimulator::Result result,
+          HeapSimulator::Result<HloValue> result,
           HeapSimulator::Run(
               get_heap_algorithm(alignment), assignment->module(), schedule,
               assignment->alias_analysis(), assignment->buffer_size_, options));
@@ -1487,7 +1490,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
       HeapSimulator::Options options;
       options.buffers_to_assign = &single_colored_set.second;
       TF_ASSIGN_OR_RETURN(
-          HeapSimulator::Result result,
+          HeapSimulator::Result<HloValue> result,
           HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                              *instruction_sequence,
                              assignment->alias_analysis(),
@@ -1582,7 +1585,7 @@ std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
 }  // namespace

 void BufferAssigner::AssignBuffersFromHeapSimulator(
-    const HeapSimulator::Result& result, BufferAssignment* assignment,
+    const HeapSimulator::Result<HloValue>& result, BufferAssignment* assignment,
     BufferValue::Color color) {
   if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
     assignment->stats_.preallocated_temp_fragmentation_bytes =
@@ -661,9 +661,9 @@ class BufferAssigner {

   // Uses the results of the heap simulator to create a single allocation, with
   // LogicalBuffers packed to specific offsets.
-  void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
-                                      BufferAssignment* assignment,
-                                      LogicalBuffer::Color color);
+  void AssignBuffersFromHeapSimulator(
+      const HeapSimulator::Result<HloValue>& result,
+      BufferAssignment* assignment, LogicalBuffer::Color color);

   // Tries to assign the given instruction to the given buffer. Returns if the
   // assignment was successful.
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
 #include "tensorflow/compiler/xla/util.h"

 namespace xla {
@@ -55,9 +56,10 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
   // rather than summing each computation, since it gives us a better lower
   // bound, by minimizing the liveness of sub-computations.
   TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
-                         schedule, *alias_analysis, size_function));
+      HeapSimulator::Result<HloValue> result,
+      HeapSimulator::Run(
+          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), *module,
+          schedule, *alias_analysis, size_function));
   return result.heap_size;
 }

@@ -69,10 +71,11 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
     const absl::flat_hash_map<const HloComputation*, int64>*
         memory_by_computation) {
   TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
-                         computation, sequence, alias_analysis, size_function,
-                         HeapSimulator::Options(), memory_by_computation));
+      HeapSimulator::Result<HloValue> result,
+      HeapSimulator::Run(
+          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
+          sequence, alias_analysis, size_function, HeapSimulator::Options(),
+          memory_by_computation));
   return result.heap_size;
 }

@@ -82,16 +85,17 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
     const LogicalBuffer::SizeFunction& size_function,
     const HloSchedule* schedule) {
   TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
-                         computation, sequence, alias_analysis, size_function,
-                         schedule, HeapSimulator::Options()));
+      HeapSimulator::Result<HloValue> result,
+      HeapSimulator::Run(
+          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
+          sequence, alias_analysis, size_function, schedule,
+          HeapSimulator::Options()));
   return result.heap_size;
 }

 /*static*/
-StatusOr<HeapSimulator::Result> HeapSimulator::Run(
-    std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
+StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
+    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm, const HloModule& module,
     const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_fn, const Options& options) {
   HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
@@ -108,8 +112,9 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
 }

 /*static*/
-StatusOr<HeapSimulator::Result> HeapSimulator::Run(
-    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
+StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
+    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
+    const HloComputation& computation,
     const HloInstructionSequence& instruction_sequence,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_fn, const Options& options,
@@ -128,8 +133,9 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
 }

 /*static*/
-StatusOr<HeapSimulator::Result> HeapSimulator::Run(
-    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
+StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
+    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
+    const HloComputation& computation,
     const HloInstructionSequence& instruction_sequence,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
@@ -326,12 +332,13 @@ Status HeapSimulator::RunComputation(
 }

 HeapSimulator::HeapSimulator(
-    std::unique_ptr<HeapAlgorithm> algorithm,
+    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
     const BufferValue::SizeFunction& size_fn, const Options& options,
     const HloSchedule* schedule,
     const absl::flat_hash_map<const HloComputation*, int64>*
         memory_by_computation)
-    : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
+    : no_fragmentation_stats_(
+          absl::make_unique<NoFragmentationStatsHeap<HloValue>>()),
       algorithm_(std::move(algorithm)),
       size_fn_(size_fn),
       options_(options),
@@ -396,8 +403,8 @@ void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                  shared);
 }

-HeapSimulator::Result HeapSimulator::Finish() {
-  Result result = algorithm_->Finish();
+HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
+  Result<HloValue> result = algorithm_->Finish();

   // Post-process the result to add chunks for shared buffers. An empty chunk
   // map means that either no buffers were allocated, or the heap was only
@@ -411,7 +418,7 @@ HeapSimulator::Result HeapSimulator::Finish() {
   }

   // Fragmentation is the difference between the actual and ideal sizes.
-  const Result no_frag_result = no_fragmentation_stats_->Finish();
+  const Result<HloValue> no_frag_result = no_fragmentation_stats_->Finish();
   result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

   // Copy the debug trace we collected to the final result.
@@ -437,14 +444,17 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
   }
 }

-void NoFragmentationStatsHeap::Alloc(const HloValue* buffer, int64 size) {
+template <typename BufferType>
+void NoFragmentationStatsHeap<BufferType>::Alloc(const BufferType* buffer,
+                                                 int64 size) {
   current_heap_size_ += size;
   if (current_heap_size_ > max_heap_size_) {
     max_heap_size_ = current_heap_size_;
   }
 }

-void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
+template <typename BufferType>
+void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
     const HloInstruction* instruction, int64 alloc_size_by_instruction,
     const absl::flat_hash_map<const HloComputation*, int64>&
         memory_by_computation) {
@@ -472,11 +482,15 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
       std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
 }

-void NoFragmentationStatsHeap::Free(const HloValue* buffer, int64 size) {
+template <typename BufferType>
+void NoFragmentationStatsHeap<BufferType>::Free(const BufferType* buffer,
+                                                int64 size) {
   current_heap_size_ -= size;
 }

-HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
+template <typename BufferType>
+HeapSimulator::Result<BufferType>
+NoFragmentationStatsHeap<BufferType>::Finish() {
   // The result.chunk_map is empty, since we only collect stats, and don't
   // actually compute chunk assignments.
   Result result;
@@ -484,7 +498,8 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
   return result;
 }

-GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap(
+template <typename BufferType>
+GlobalDecreasingSizeBestFitHeap<BufferType>::GlobalDecreasingSizeBestFitHeap(
     int64 alignment, Type type)
     : alignment_(alignment) {
   if (type == kTemporal) {
@@ -495,8 +510,10 @@ GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap(
   }
 }

-GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
-GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const {
+template <typename BufferType>
+typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferIntervalCompare
+GlobalDecreasingSizeBestFitHeap<BufferType>::GetTemporalBufferIntervalCompare()
+    const {
   return [&](const BufferInterval& x, const BufferInterval& y) {
     int64 x_end = x.end;
     for (auto colocation : GetTransitiveColocations(x)) {
@@ -515,12 +532,14 @@ GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const {
     if (x.size != y.size) {
       return x.size > y.size;
     }
-    return x.buffer->id() < y.buffer->id();
+    return *x.buffer < *y.buffer;
   };
 }

-/*static*/ GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
-GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() {
+template <typename BufferType>
+/*static*/ typename GlobalDecreasingSizeBestFitHeap<
+    BufferType>::BufferIntervalCompare
+GlobalDecreasingSizeBestFitHeap<BufferType>::GetSpatialBufferIntervalCompare() {
   return [&](const BufferInterval& x, const BufferInterval& y) {
     if (x.size != y.size) {
       return x.size > y.size;
@@ -528,12 +547,13 @@ GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() {
     if (x.end - x.start != y.end - y.start) {
       return x.end - x.start > y.end - y.start;
     }
-    return x.buffer->id() < y.buffer->id();
+    return *x.buffer < *y.buffer;
   };
 }

-void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer,
-                                            int64 size) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
+    const BufferType* buffer, int64 size) {
   // Degenerate case: 0-sized buffers are always allocated at offset 0.
   if (size == 0) {
     result_.chunk_map.emplace(buffer, Chunk{0, 0});
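The tie-breaker change in both comparators above is the heart of the templating: instead of calling HloValue::id(), the sort defers to operator< on the buffer type itself, so any type plugged in must supply a strict, deterministic ordering. HloValue already orders by id; AllocationBlock gains a matching operator< later in this diff. A hypothetical minimal buffer type would look like:

    struct OrderedBlock {  // hypothetical; mirrors AllocationBlock's ordering
      int64 id;            // unique and deterministic
      bool operator<(const OrderedBlock& other) const { return id < other.id; }
    };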
@@ -546,9 +566,9 @@ void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer,
   ++current_time_;
 }

-void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer,
-                                                const HloValue* share_with,
-                                                int64 size) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
+    const BufferType* buffer, const BufferType* share_with, int64 size) {
   // Degenerate case: 0-sized buffers are always allocated at offset 0.
   if (size == 0) {
     result_.chunk_map.emplace(buffer, Chunk{0, 0});
@@ -562,15 +582,16 @@ void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer,
   ++current_time_;
 }

-absl::flat_hash_set<const HloValue*>
-GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations(
+template <typename BufferType>
+absl::flat_hash_set<const BufferType*>
+GlobalDecreasingSizeBestFitHeap<BufferType>::GetTransitiveColocations(
     const BufferInterval& interval) const {
-  absl::flat_hash_set<const HloValue*> result;
+  absl::flat_hash_set<const BufferType*> result;
   std::vector<const BufferInterval*> worklist = {&interval};
   while (!worklist.empty()) {
     const BufferInterval* item = worklist.back();
     worklist.pop_back();
-    for (const HloValue* buffer_colocated : item->colocations) {
+    for (const BufferType* buffer_colocated : item->colocations) {
       result.insert(buffer_colocated);
       worklist.push_back(&buffer_intervals_.at(buffer_colocated));
     }
@@ -579,7 +600,9 @@ GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations(
   return result;
 }

-void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
+                                                       int64 size) {
   // Degenerate case: 0-sized buffers are always allocated at offset 0.
   if (size == 0) {
     return;
@@ -785,7 +808,9 @@ std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
   return result;
 }

-HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
+template <typename BufferType>
+HeapSimulator::Result<BufferType>
+GlobalDecreasingSizeBestFitHeap<BufferType>::Finish() {
   std::vector<BufferInterval> sorted_buffer_intervals =
       GetSortedBufferIntervals();

@@ -803,8 +828,10 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
   return result_;
 }

-std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval>
-GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const {
+template <typename BufferType>
+std::vector<
+    typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval>
+GlobalDecreasingSizeBestFitHeap<BufferType>::GetSortedBufferIntervals() const {
   std::vector<BufferInterval> sorted_buffer_intervals;
   for (auto& entry : buffer_intervals_) {
     sorted_buffer_intervals.push_back(entry.second);
@@ -814,8 +841,9 @@ GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const {
   return sorted_buffer_intervals;
 }

-GlobalDecreasingSizeBestFitHeap::ChunkCandidate
-GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
+template <typename BufferType>
+typename GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
+GlobalDecreasingSizeBestFitHeap<BufferType>::FindChunkCandidate(
     const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
     int64 preferred_offset) const {
   VLOG(1) << "Finding chunks for buffer: "
@@ -912,9 +940,12 @@ GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
   return chunk_candidate;
 }

-void GlobalDecreasingSizeBestFitHeap::CommitChunk(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
-    GlobalDecreasingSizeBestFitHeap::ChunkCandidate chunk_candidate) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::CommitChunk(
+    const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval&
+        buffer_interval,
+    GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
+        chunk_candidate) {
   // Update the maximum heap size according to the one determined by the chunk
   // candidate.
   result_.heap_size = chunk_candidate.heap_size;
@@ -930,13 +961,16 @@ void GlobalDecreasingSizeBestFitHeap::CommitChunk(
   AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk);
 }

-void GlobalDecreasingSizeBestFitHeap::AddToChunkMap(const HloValue* buffer,
-                                                    Chunk chunk) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::AddToChunkMap(
+    const BufferType* buffer, Chunk chunk) {
   const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
   DCHECK(emplace_result.second);
 }

-HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
+template <typename BufferType>
+HeapSimulator::Result<BufferType>
+ChooseBestHeapAlgorithm<BufferType>::Finish() {
   DCHECK(!algorithms_.empty());
   std::vector<Result> results(algorithms_.size());
   int64 min_size = INT64_MAX;
@@ -953,4 +987,9 @@ HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
   return results[min_size_index];
 }

+template class GlobalDecreasingSizeBestFitHeap<HloValue>;
+template class GlobalDecreasingSizeBestFitHeap<
+    MemorySpaceAssignmentRepacker::AllocationBlock>;
+template class ChooseBestHeapAlgorithm<HloValue>;
+
 }  // namespace xla
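The explicit instantiations above are what keep the template definitions inside heap_simulator.cc rather than in the header. Supporting a further opaque type would mean adding one more line; hypothetically (MyBlock is not part of this commit):

    // Hypothetical: enable the best-fit heap for another opaque type.
    template class GlobalDecreasingSizeBestFitHeap<MyBlock>;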
@@ -40,7 +40,9 @@ limitations under the License.
 namespace xla {

 // Forward declare classes defined below.
+template <typename BufferType>
 class HeapAlgorithm;
+template <typename BufferType>
 class NoFragmentationStatsHeap;

 // HeapSimulator assigns buffer offsets by running a simulation of a regular
@@ -66,9 +68,10 @@ class HeapSimulator {
   };

   // Result represents the result of the heap simulation.
+  template <typename BufferType>
   struct Result {
     // The assignment of buffers to chunks.
-    absl::flat_hash_map<const HloValue*, Chunk> chunk_map;
+    absl::flat_hash_map<const BufferType*, Chunk> chunk_map;

     // The total size in bytes of the heap, containing all assigned chunks.
     int64 heap_size = 0;
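Since Result is now parameterized, callers read chunks back keyed by whatever handle type they allocated with. A small consumption sketch; the helper name and logging are illustrative, not from the commit:

    void PrintChunks(const HeapSimulator::Result<HloValue>& result) {
      for (const auto& entry : result.chunk_map) {
        const HloValue* value = entry.first;
        const HeapSimulator::Chunk& chunk = entry.second;
        LOG(INFO) << value->ToShortString() << " -> offset " << chunk.offset
                  << ", size " << chunk.size;
      }
    }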
@@ -128,19 +131,19 @@ class HeapSimulator {
   // to running on a per-computation basis, since we can re-use buffer space for
   // called sub-computations.
   //
-  static StatusOr<Result> Run(std::unique_ptr<HeapAlgorithm> algorithm,
-                              const HloModule& module,
-                              const HloSchedule& schedule,
-                              const HloAliasAnalysis& alias_analysis,
-                              const BufferValue::SizeFunction& size_fn,
-                              const Options& options = Options());
+  static StatusOr<Result<HloValue>> Run(
+      std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
+      const HloModule& module, const HloSchedule& schedule,
+      const HloAliasAnalysis& alias_analysis,
+      const BufferValue::SizeFunction& size_fn,
+      const Options& options = Options());

   // Same as above, but runs on a single computation. The 'instruction_sequence'
   // must contain a topologically-consistent total ordering of all instructions
   // in the computation. The result is invalid if instructions are not run in
   // exactly this sequence.
-  static StatusOr<Result> Run(
-      std::unique_ptr<HeapAlgorithm> algorithm,
+  static StatusOr<Result<HloValue>> Run(
+      std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
       const HloComputation& computation,
       const HloInstructionSequence& instruction_sequence,
       const HloAliasAnalysis& alias_analysis,
@@ -151,8 +154,8 @@ class HeapSimulator {

   // Same as above, but runs on with a schedule that covers all nested
   // computations.
-  static StatusOr<Result> Run(
-      std::unique_ptr<HeapAlgorithm> algorithm,
+  static StatusOr<Result<HloValue>> Run(
+      std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
       const HloComputation& computation,
       const HloInstructionSequence& instruction_sequence,
       const HloAliasAnalysis& alias_analysis,
@@ -163,7 +166,7 @@ class HeapSimulator {
   // If 'schedule' is non-null, it is used to find kCall and kWhile
   // sub-computations, and the heap simulation for those sub-computations will
   // be run recursively. I.e. the simulation is run over the whole module.
-  HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
+  HeapSimulator(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
                 const BufferValue::SizeFunction& size_fn,
                 const Options& options, const HloSchedule* schedule = nullptr,
                 const absl::flat_hash_map<const HloComputation*, int64>*
@@ -187,7 +190,7 @@ class HeapSimulator {
   //  Two buffers belong to the same shared group.
   //  Eight of the buffer has no shared group assigned.
   bool InSameSharedGroup(const HloValue* left, const HloValue* right);
-  Result Finish();
+  Result<HloValue> Finish();

   void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                       const HloValue* buffer, const HloInstruction* instruction,
@@ -196,8 +199,9 @@ class HeapSimulator {
   // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap,
   // in which case we are calculating the same allocs/frees twice in the
   // simulation.
-  const std::unique_ptr<NoFragmentationStatsHeap> no_fragmentation_stats_;
-  const std::unique_ptr<HeapAlgorithm> algorithm_;
+  const std::unique_ptr<NoFragmentationStatsHeap<HloValue>>
+      no_fragmentation_stats_;
+  const std::unique_ptr<HeapAlgorithm<HloValue>> algorithm_;
   const BufferValue::SizeFunction size_fn_;
   const Options options_;
   // schedule_ is set by buffer assignment, and memory_by_computation_ is
@@ -220,15 +224,16 @@ class HeapSimulator {
 // offsets to buffers. A sequence of Alloc / Free calls will be made, with the
 // same semantics as a regular memory heap. Finish will be called at the end to
 // collect the simulation results.
+template <typename BufferType>
 class HeapAlgorithm {
  public:
   using Chunk = HeapSimulator::Chunk;
-  using Result = HeapSimulator::Result;
+  using Result = HeapSimulator::Result<BufferType>;

   virtual ~HeapAlgorithm() = default;

   // Alloc allocates a buffer of 'size' bytes.
-  virtual void Alloc(const HloValue* buffer, int64 size) = 0;
+  virtual void Alloc(const BufferType* buffer, int64 size) = 0;

   // Takes memory usage of subcomputations into account when calculating the
   // memory usage of a computation. Currently, we don't handle buffer aliasing
@@ -247,7 +252,7 @@ class HeapAlgorithm {
           memory_by_computation) {}

   // Free de-allocates a previously allocated buffer.
-  virtual void Free(const HloValue* buffer, int64 size) = 0;
+  virtual void Free(const BufferType* buffer, int64 size) = 0;

   // Indicates that a buffer has to be collocated with another buffer. In
   // addition to Alloc and Free, the heap simulator exposes a concept of buffer
@@ -255,7 +260,7 @@ class HeapAlgorithm {
   // the buffer, it associates the buffer with a previously allocated (or
   // shared) buffer. Each group of mutually-shared buffers points to a single
   // SharedGroup instance, which is a shared control block.
-  virtual void ShareWith(const HloValue* buffer, const HloValue* share_with,
+  virtual void ShareWith(const BufferType* buffer, const BufferType* share_with,
                          int64 size) {
     Alloc(buffer, size);
   }
@@ -269,19 +274,22 @@ class HeapAlgorithm {
 // this is the absolute minimum size for a given instruction sequence. The
 // result.chunk_map returned in Finish is always empty, since we only collect
 // stats, and don't actually compute chunk assignments.
-class NoFragmentationStatsHeap : public HeapAlgorithm {
+template <typename BufferType>
+class NoFragmentationStatsHeap : public HeapAlgorithm<BufferType> {
  public:
+  using Result = HeapSimulator::Result<BufferType>;
+
   NoFragmentationStatsHeap() = default;
   ~NoFragmentationStatsHeap() override = default;

-  void Alloc(const HloValue* buffer, int64 size) override;
+  void Alloc(const BufferType* buffer, int64 size) override;

   void AccountForSubcomputationMemory(
       const HloInstruction* instruction, int64 alloc_size_by_instruction,
       const absl::flat_hash_map<const HloComputation*, int64>&
           memory_by_computation) override;

-  void Free(const HloValue* buffer, int64 size) override;
+  void Free(const BufferType* buffer, int64 size) override;

   Result Finish() override;

@@ -336,8 +344,12 @@ class BufferIntervalTree {
 // alloc/free time. It internally tracks the allocated buffers and their live
 // intervals; when allocating a buffer, it finds the best-fit free chunk during
 // its live interval.
-class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
+template <typename BufferType>
+class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm<BufferType> {
  public:
+  using Result = HeapSimulator::Result<BufferType>;
+  using Chunk = HeapSimulator::Chunk;
+
   enum Type {
     kSpatial = 0,
     kTemporal,
@@ -345,7 +357,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {

   // BufferInterval stores a buffer's size and time interval.
   struct BufferInterval {
-    const HloValue* buffer;
+    const BufferType* buffer;
     int64 size;
     // Alloc time of the buffer.
    int64 start;
@@ -353,7 +365,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
     int64 end;

     // Colocation buffers that need to be collocated with this one.
-    std::vector<const HloValue*> colocations;
+    std::vector<const BufferType*> colocations;

     // True if this buffer needs an allocation. False if it is collocated with
     // other buffer.
@@ -368,10 +380,10 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
                                   Type type = kSpatial);
   ~GlobalDecreasingSizeBestFitHeap() override {}

-  void Alloc(const HloValue* buffer, int64 size) override;
-  void Free(const HloValue* buffer, int64 size) override;
+  void Alloc(const BufferType* buffer, int64 size) override;
+  void Free(const BufferType* buffer, int64 size) override;

-  void ShareWith(const HloValue* buffer, const HloValue* share_with,
+  void ShareWith(const BufferType* buffer, const BufferType* share_with,
                  int64 size) override;

   Result Finish() override;
@@ -404,7 +416,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
   void CommitChunk(const BufferInterval& buffer_interval,
                    ChunkCandidate chunk_candidate);
   // Adds the buffer and the chunk to the result chunk map.
-  virtual void AddToChunkMap(const HloValue* buffer, Chunk chunk);
+  virtual void AddToChunkMap(const BufferType* buffer, Chunk chunk);

   // Return a BufferIntervalCompare function that sorts by live ranges. A live
   // range is defined by the range between the start of the first buffer and the
@@ -413,7 +425,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
   // contiguous.
   BufferIntervalCompare GetTemporalBufferIntervalCompare() const;

-  absl::flat_hash_map<const HloValue*, BufferInterval> buffer_intervals_;
+  absl::flat_hash_map<const BufferType*, BufferInterval> buffer_intervals_;
   Result result_;
   BufferIntervalCompare buffer_interval_compare_;
   BufferIntervalTree interval_tree_;
@@ -428,33 +440,37 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
   // Returns all transitive colocated buffers of this buffer interval. I.e., If
   // a buffer A is colocated with B and B is colocated with C, this function
   // returns all three of them.
-  absl::flat_hash_set<const HloValue*> GetTransitiveColocations(
+  absl::flat_hash_set<const BufferType*> GetTransitiveColocations(
       const BufferInterval& interval) const;
 };

 // A heap algorithm that chooses the best results from other algorithms added to
 // it.
-class ChooseBestHeapAlgorithm : public HeapAlgorithm {
+template <typename BufferType>
+class ChooseBestHeapAlgorithm : public HeapAlgorithm<BufferType> {
  public:
+  using Result = HeapSimulator::Result<BufferType>;

   ChooseBestHeapAlgorithm(
-      std::unique_ptr<std::vector<std::unique_ptr<HeapAlgorithm>>> algorithms)
+      std::unique_ptr<std::vector<std::unique_ptr<HeapAlgorithm<BufferType>>>>
+          algorithms)
       : algorithms_(std::move(*algorithms)) {}
   ~ChooseBestHeapAlgorithm() override {}

-  void Alloc(const HloValue* buffer, int64 size) override {
+  void Alloc(const BufferType* buffer, int64 size) override {
     for (auto& algorithm : algorithms_) {
       algorithm->Alloc(buffer, size);
     }
   }

-  void ShareWith(const HloValue* buffer, const HloValue* share_with,
+  void ShareWith(const BufferType* buffer, const BufferType* share_with,
                  int64 size) override {
     for (auto& algorithm : algorithms_) {
       algorithm->ShareWith(buffer, share_with, size);
     }
   }

-  void Free(const HloValue* buffer, int64 size) override {
+  void Free(const BufferType* buffer, int64 size) override {
     for (auto& algorithm : algorithms_) {
       algorithm->Free(buffer, size);
     }
@@ -463,7 +479,7 @@ class ChooseBestHeapAlgorithm : public HeapAlgorithm {
   Result Finish() override;

  private:
-  std::vector<std::unique_ptr<HeapAlgorithm>> algorithms_;
+  std::vector<std::unique_ptr<HeapAlgorithm<BufferType>>> algorithms_;
 };

 }  // namespace xla
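With the interface now templated end to end, a custom algorithm only has to implement Alloc, Free, and Finish for its buffer type; AccountForSubcomputationMemory and ShareWith have default implementations. A minimal sketch, not part of the commit, that simply stacks every buffer at a fresh offset:

    // Hypothetical algorithm over the templated interface.
    template <typename BufferType>
    class BumpAllocatorHeap : public HeapAlgorithm<BufferType> {
     public:
      using Result = HeapSimulator::Result<BufferType>;

      void Alloc(const BufferType* buffer, int64 size) override {
        result_.chunk_map.emplace(buffer, HeapSimulator::Chunk{offset_, size});
        offset_ += size;
        result_.heap_size = offset_;  // never reuses freed space
      }
      void Free(const BufferType* buffer, int64 size) override {}
      Result Finish() override { return result_; }

     private:
      int64 offset_ = 0;
      Result result_;
    };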
@@ -228,7 +228,7 @@ const char kFinish[] = "Finish";
 using CallSequence = std::vector<std::pair<string, const HloValue*>>;

 // HeapCallRecorder is a dummy heap algorithm that simply records its calls.
-class HeapCallRecorder : public HeapAlgorithm {
+class HeapCallRecorder : public HeapAlgorithm<HloValue> {
  public:
   explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
   ~HeapCallRecorder() override {}
@@ -396,7 +396,7 @@ class HeapSimulatorTracker {
   std::unique_ptr<HloModule> module_;
   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
   CallSequence actual_calls_;
-  HeapSimulator::Result result_;
+  HeapSimulator::Result<HloValue> result_;
 };

 class HeapSimulatorTest : public HloTestBase {
@@ -976,12 +976,12 @@ class HeapAlgorithmTestBase : public ::testing::Test {
 class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};

 TEST_F(NoFragmentationStatsHeapTest, Empty) {
-  NoFragmentationStatsHeap heap;
+  NoFragmentationStatsHeap<HloValue> heap;
   EXPECT_EQ(0, heap.Finish().heap_size);
 }

 TEST_F(NoFragmentationStatsHeapTest, Simple) {
-  NoFragmentationStatsHeap heap;
+  NoFragmentationStatsHeap<HloValue> heap;
   heap.Alloc(buffer_a_, 10);
   heap.Alloc(buffer_b_, 20);
   heap.Alloc(buffer_c_, 30);
@@ -994,7 +994,7 @@ TEST_F(NoFragmentationStatsHeapTest, Simple) {
 }

 TEST_F(NoFragmentationStatsHeapTest, Mixed) {
-  NoFragmentationStatsHeap heap;
+  NoFragmentationStatsHeap<HloValue> heap;
   heap.Alloc(buffer_a_, 10);  // max: A

   heap.Alloc(buffer_b_, 20);  // max: A+B
@@ -1013,7 +1013,7 @@ TEST_F(NoFragmentationStatsHeapTest, Mixed) {
 class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
  protected:
   class InheritedGlobalDecreasingSizeBestFitHeap
-      : public GlobalDecreasingSizeBestFitHeap {
+      : public GlobalDecreasingSizeBestFitHeap<HloValue> {
    public:
    InheritedGlobalDecreasingSizeBestFitHeap()
        : GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {}
@@ -1048,8 +1048,8 @@ class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
 };

 TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
-  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
-  const HeapSimulator::Result result = heap.Finish();
+  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
+  const HeapSimulator::Result<HloValue> result = heap.Finish();
   EXPECT_EQ(0, result.heap_size);
   EXPECT_EQ(0, result.chunk_map.size());
 }
@@ -1068,7 +1068,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
   // | | d |
   // | +-------+
   // -----------------> time
-  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
   heap.Alloc(buffer_a_, 10);
   heap.Alloc(buffer_b_, 30);
   heap.Alloc(buffer_c_, 20);
@@ -1078,7 +1078,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
   heap.Free(buffer_c_, 20);
   heap.Free(buffer_d_, 40);

-  const HeapSimulator::Result result = heap.Finish();
+  const HeapSimulator::Result<HloValue> result = heap.Finish();
   EXPECT_EQ(100, result.heap_size);
   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
@@ -1107,7 +1107,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
   // | | |
   // | +-------+
   // ---------------------> time
-  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
+  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/20);
   heap.Alloc(buffer_a_, 10);
   heap.Alloc(buffer_b_, 20);
   heap.Alloc(buffer_c_, 50);
@@ -1117,7 +1117,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
   heap.Free(buffer_c_, 50);
   heap.Free(buffer_d_, 40);

-  const HeapSimulator::Result result = heap.Finish();
+  const HeapSimulator::Result<HloValue> result = heap.Finish();
   EXPECT_EQ(120, result.heap_size);
   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
@@ -1148,7 +1148,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
   // | | |
   // | +-------+
   // ---------------------> time
-  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
   heap.Alloc(buffer_a_, 10);
   heap.Alloc(buffer_b_, 20);
   heap.Alloc(buffer_c_, 40);
@@ -1160,7 +1160,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
   heap.Free(buffer_d_, 30);
   heap.Free(buffer_e_, 50);

-  const HeapSimulator::Result result = heap.Finish();
+  const HeapSimulator::Result<HloValue> result = heap.Finish();
   EXPECT_EQ(140, result.heap_size);
   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
@@ -1184,7 +1184,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) {
   // || |+----+| |
   // |+--a---++-b--++---c---+
   // ---------------------> time
-  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
   heap.Alloc(buffer_a_, 40);
   heap.Free(buffer_a_, 40);
   heap.Alloc(buffer_b_, 20);
@@ -1192,7 +1192,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) {
   heap.ShareWith(buffer_c_, buffer_a_, 40);
   heap.Free(buffer_c_, 40);

-  const HeapSimulator::Result result = heap.Finish();
+  const HeapSimulator::Result<HloValue> result = heap.Finish();
   EXPECT_EQ(40, result.heap_size);
   EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
@@ -1212,7 +1212,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
   // || | | | <--- colocate with a
   // |+--a---+ +---c---+
   // ---------------------> time
-  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
   heap.Alloc(buffer_a_, 40);
   heap.Free(buffer_a_, 40);
   heap.Alloc(buffer_b_, 20);
@@ -1221,7 +1221,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
   heap.Free(buffer_c_, 40);
   heap.Free(buffer_b_, 20);

-  const HeapSimulator::Result result = heap.Finish();
+  const HeapSimulator::Result<HloValue> result = heap.Finish();
   EXPECT_EQ(60, result.heap_size);
   EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
@@ -1242,7 +1242,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) {
   // | | |
   // | +-------b-------+
   // ---------------------> time
-  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
   heap.Alloc(buffer_a_, 10);
   heap.Free(buffer_a_, 10);
   heap.Alloc(buffer_b_, 30);
@@ -1251,7 +1251,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) {
   heap.Free(buffer_c_, 10);
   heap.Free(buffer_b_, 30);

-  const HeapSimulator::Result result = heap.Finish();
+  const HeapSimulator::Result<HloValue> result = heap.Finish();
   EXPECT_EQ(40, result.heap_size);
   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
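All of the updated tests pin the HloValue instantiation. A hypothetical test against a non-HLO buffer type, sketched here to show what the templating buys (not part of the commit):

    TEST(OpaqueBufferHeapTest, SingleBlock) {
      using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock;
      AllocationBlock block{};  // value-initialize all fields
      block.id = 0;
      GlobalDecreasingSizeBestFitHeap<AllocationBlock> heap(/*alignment=*/1);
      heap.Alloc(&block, 10);
      heap.Free(&block, 10);
      EXPECT_EQ(10, heap.Finish().heap_size);
    }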
@@ -50,9 +50,9 @@ int64 PeakMemoryUseOfEntryComputation(

   HloComputation* computation = module->entry_computation();
   const HloInstructionSequence& sequence = schedule.sequence(computation);
-  return HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
-                            *computation, sequence, *alias_analysis,
-                            size_function)
+  return HeapSimulator::Run(
+             absl::make_unique<NoFragmentationStatsHeap<HloValue>>(),
+             *computation, sequence, *alias_analysis, size_function)
       .ValueOrDie()
       .heap_size;
 }
@@ -80,7 +80,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
 }

 float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
+    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
     MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
   const HloInstruction& defining_instruction =
       *interval.buffer->defining_instruction();
@@ -570,7 +570,8 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(

 absl::optional<float>
 CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
+    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
+    const {
   return cost_analysis_.GetMemoryBoundedness(interval);
 }

@@ -733,9 +734,9 @@ void AlternateMemoryBestFitHeap::FindAliases(
   }
 }

-std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
+std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>
 AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
+    const AlternateMemoryBestFitHeap::BufferInterval& interval) const {
   std::vector<const BufferInterval*> colocated_intervals;
   std::vector<const BufferInterval*> worklist = {&interval};
   while (!worklist.empty()) {
@@ -864,7 +865,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
 }

 void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
+    const AlternateMemoryBestFitHeap::BufferInterval& interval,
     std::string* debug_str) const {
   // Columns in buffer information:
   // buffer_id: int. This value can be used to match the allocation in
@@ -954,7 +955,7 @@ void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
   options_.dump_fn("allocinfo", allocation_info_str_);
 }

-HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
+HeapSimulator::Result<HloValue> AlternateMemoryBestFitHeap::Finish() {
   std::vector<BufferInterval> sorted_buffer_intervals =
       GetSortedBufferIntervals();

@@ -1390,10 +1391,10 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
   MemorySpaceAssignment::Allocation* last_allocation =
       allocations_->at(1).get();
   CHECK(last_allocation->memory_space() == MemorySpace::kAlternate);
-  repack_allocation_blocks_.push_back(RepackAllocationBlock(
+  repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
       last_allocation->start_time(), last_allocation->end_time(),
       last_allocation->chunk().size, last_allocation->chunk().offset,
-      last_allocation));
+      static_cast<int64>(repack_allocation_blocks_.size()), last_allocation));
   repack_allocation_blocks_.back().colocations.push_back(
       &repack_allocation_blocks_.back());

@@ -1671,10 +1672,12 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations(
     std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
     for (MemorySpaceAssignment::Allocation* colocated_allocation :
          colocation.second) {
-      repack_allocation_blocks_.push_back(RepackAllocationBlock(
+      repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
          colocated_allocation->start_time(), colocated_allocation->end_time(),
          colocated_allocation->chunk().size,
-          colocated_allocation->chunk().offset, colocated_allocation));
+          colocated_allocation->chunk().offset,
+          static_cast<int64>(repack_allocation_blocks_.size()),
+          colocated_allocation));
       colocations.push_back(&repack_allocation_blocks_.back());
     }
     for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
@@ -2369,8 +2372,8 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
       return x_memory_boundedness > y_memory_boundedness;
     }
     // Tie-break if the memory boundedness is the same.
-    return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()(
-        x, y);
+    return GlobalDecreasingSizeBestFitHeap<
+        HloValue>::GetSpatialBufferIntervalCompare()(x, y);
   };
 }

@@ -106,7 +106,7 @@ class MemorySpaceAssignmentCostAnalysis {
   // BufferInterval. The larger this number, the higher priority it will be
   // placed in the alternate memory.
   float GetMemoryBoundedness(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
+      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
      Cache* cache = nullptr) const;

   // Returns the elapsed time in seconds due to compute only.
@@ -235,7 +235,8 @@ class PrefetchIntervalPicker {
   // of placing the BufferInterval in the alternate memory. The larger value,
   // the more beneficial.
   virtual absl::optional<float> BufferIntervalAlternateMemoryBenefit(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
+      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
+      const {
     return absl::nullopt;
   }

@@ -324,7 +325,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
                                   int64 end_time) const override;

   absl::optional<float> BufferIntervalAlternateMemoryBenefit(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval)
+      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
       const override;

  private:
@@ -370,9 +371,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
 class MemorySpaceAssignment {
  public:
   using Chunk = HeapSimulator::Chunk;
-  using BufferInterval = GlobalDecreasingSizeBestFitHeap::BufferInterval;
+  using BufferInterval =
+      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval;
   using BufferIntervalCompare =
-      GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare;
+      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare;
   using IsAllowedInAlternateMemoryFunction =
       std::function<bool(const HloValue&)>;

@@ -913,7 +915,8 @@ class AsynchronousCopyOrdering {

 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
 // maximum size.
-class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
+class AlternateMemoryBestFitHeap
+    : public GlobalDecreasingSizeBestFitHeap<HloValue> {
  public:
   using MemorySpace = MemorySpaceAssignment::MemorySpace;
   using AllocationValue = MemorySpaceAssignment::AllocationValue;
@@ -940,25 +943,13 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   void AllocateCrossProgramPrefetchBuffer(
       HloModule* module, absl::optional<BufferInterval> prefetch_candidate);

-  HeapSimulator::Result Finish() override;
+  HeapSimulator::Result<HloValue> Finish() override;

  private:
   // We inherit AllocationBlock struct to attach the Allocation information to
   // make importing repacked offsets easier.
   struct RepackAllocationBlock
       : MemorySpaceAssignmentRepacker::AllocationBlock {
-    RepackAllocationBlock(int64 start_time, int64 end_time, int64 size,
-                          int64 initial_offset,
-                          MemorySpaceAssignment::Allocation* allocation) {
-      this->start_time = start_time;
-      this->end_time = end_time;
-      this->size = size;
-      this->offset = -1;
-      this->initial_offset = initial_offset;
-      this->colocations = {};
-      this->allocation = allocation;
-    }
-
     MemorySpaceAssignment::Allocation* allocation;
   };

@@ -1231,6 +1222,22 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
     return options_.max_size_in_bytes - reserved_in_bytes_;
   }

+  // Creates and returns a RepackAllocationBlock.
+  static RepackAllocationBlock MakeRepackAllocationBlock(
+      int64 start_time, int64 end_time, int64 size, int64 initial_offset,
+      int64 id, MemorySpaceAssignment::Allocation* allocation) {
+    RepackAllocationBlock allocation_block;
+    allocation_block.start_time = start_time;
+    allocation_block.end_time = end_time;
+    allocation_block.size = size;
+    allocation_block.offset = -1;
+    allocation_block.initial_offset = initial_offset;
+    allocation_block.id = id;
+    allocation_block.colocations = {};
+    allocation_block.allocation = allocation;
+    return allocation_block;
+  }
+
   MemorySpaceAssignment::AllocationSequence* allocations_;
   const MemorySpaceAssignment::Options& options_;
   const HloAliasAnalysis& alias_analysis_;
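The derived-struct constructor is replaced by this static factory, which assigns the base-struct fields by name and makes the new id parameter explicit. Call sites pass the block's index as the id, which is unique and deterministic. A usage sketch; allocation stands in for the concrete pointers used at the two updated call sites:

    repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
        allocation->start_time(), allocation->end_time(),
        allocation->chunk().size, allocation->chunk().offset,
        /*id=*/static_cast<int64>(repack_allocation_blocks_.size()),
        allocation));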
@@ -33,14 +33,26 @@ class MemorySpaceAssignmentRepacker {
   // successful and the allocations were modified, the offset field holds the
   // new offset. To support aliased allocations, AllocationBlock also includes a
   // vector of AllocationBlock pointers, called colocations. All AllocationBlock
-  // objects within the colocations must get the same offset.
+  // objects within the colocations must get the same offset. The id should be
+  // unique and is used to ensure determinism for comparison tie-breaker.
   struct AllocationBlock {
     int64 start_time;
     int64 end_time;
     int64 size;
     int64 offset;
     int64 initial_offset;
+    int64 id;
     std::vector<AllocationBlock*> colocations;
+
+    std::string ToString() const {
+      return absl::StrCat("[", start_time, ", ", end_time, "] : size = ", size,
+                          ", offset = ", offset,
+                          " initial offset = ", initial_offset);
+    }
+
+    // This is required by BufferIntervalCompare as a tie breaker. Use a unique
+    // and deterministic id.
+    bool operator<(const AllocationBlock& other) const { return id < other.id; }
   };

   // Repack the AllocationBlocks provided in the parameter. Returns true if
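The new id and operator< make AllocationBlock usable as the heap's buffer type, matching the explicit instantiation added in heap_simulator.cc. A sketch of the direction this enables; the actual repacker lands in a follow-up, and a real one would derive each block's live interval from its start_time/end_time rather than from call order as this naive trace does:

    using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock;

    void SketchRepack(const std::vector<AllocationBlock*>& blocks) {
      GlobalDecreasingSizeBestFitHeap<AllocationBlock> heap(/*alignment=*/1);
      for (AllocationBlock* block : blocks) {
        heap.Alloc(block, block->size);  // every block overlaps in this sketch
      }
      for (AllocationBlock* block : blocks) {
        heap.Free(block, block->size);
      }
      HeapSimulator::Result<AllocationBlock> result = heap.Finish();
      for (AllocationBlock* block : blocks) {
        block->offset = result.chunk_map.at(block).offset;  // import offsets
      }
    }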
@@ -18,7 +18,7 @@ limitations under the License.
 namespace xla {

 bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
+    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
   // If the buffer is a tuple, don't use this algorithm for now. The buffers
   // that are pointed to by the tuple will still use this algorithm. Because
   // tuples are cheap to place in the alternate memory (they are just pointers)
@@ -26,7 +26,8 @@ class MemorySpaceAssignmentUtils {
   // Returns true if this buffer is allowed to be placed in the alternate
   // memory.
   static bool IsIntervalAllowedInAlternateMemory(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval);
+      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval&
+          interval);
 };

 }  // namespace xla