[XLA] Use templates in heap simulator to allow opaque type to be different than HloValue (NFC)

This CL allows reusing the heap algorithm machinery for opaque types other than HloValue. This is in preparation for using heap algorithms as memory space assignment repackers to reduce fragmentation of the alternate memory.

PiperOrigin-RevId: 326745711
Change-Id: I30845956ee22a1958eb7ea39a9653f1cefa7691b
parent 73b40908a4
commit ba58b8cafa
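
The change is mechanical but wide: the heap-algorithm classes gain a BufferType
template parameter, and every "const HloValue*" in their interfaces becomes
"const BufferType*"; existing callers pin the parameter to HloValue. A minimal
sketch of the resulting interface shape, condensed from the diff below (member
lists abbreviated, unrelated details omitted):

    template <typename BufferType>
    class HeapAlgorithm {
     public:
      using Result = HeapSimulator::Result<BufferType>;

      virtual ~HeapAlgorithm() = default;

      // Alloc/Free drive the simulated lifetime of each buffer; Finish
      // returns the packed chunks.
      virtual void Alloc(const BufferType* buffer, int64 size) = 0;
      virtual void Free(const BufferType* buffer, int64 size) = 0;
      virtual Result Finish() = 0;
    };

    // Existing call sites become, e.g.:
    //   GlobalDecreasingSizeBestFitHeap<HloValue> heap(alignment);
    // while the planned repacker can instantiate the same machinery with
    // MemorySpaceAssignmentRepacker::AllocationBlock as the buffer type.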
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1431,6 +1431,7 @@ cc_library(
         ":hlo_live_range",
         ":hlo_ordering",
         ":hlo_proto_cc",
+        ":memory_space_assignment_repacking",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",

--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1424,13 +1424,16 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
   // Returns a heap algorithm that chooses the best result from several
   // algorithms.
   auto get_heap_algorithm = [&](int64 alignment) {
-    auto algorithms =
-        absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
-    algorithms->push_back(absl::make_unique<GlobalDecreasingSizeBestFitHeap>(
-        alignment, GlobalDecreasingSizeBestFitHeap::kSpatial));
-    algorithms->push_back(absl::make_unique<GlobalDecreasingSizeBestFitHeap>(
-        alignment, GlobalDecreasingSizeBestFitHeap::kTemporal));
-    return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
+    auto algorithms = absl::make_unique<
+        std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
+    algorithms->push_back(
+        absl::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
+            alignment, GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
+    algorithms->push_back(
+        absl::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
+            alignment, GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
+    return absl::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
+        std::move(algorithms));
   };

   if (run_whole_module_heap_simulation) {
@@ -1461,7 +1464,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
       options.buffers_to_assign = &single_colored_set.second;

       TF_ASSIGN_OR_RETURN(
-          HeapSimulator::Result result,
+          HeapSimulator::Result<HloValue> result,
           HeapSimulator::Run(
               get_heap_algorithm(alignment), assignment->module(), schedule,
               assignment->alias_analysis(), assignment->buffer_size_, options));
@@ -1487,7 +1490,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
         HeapSimulator::Options options;
         options.buffers_to_assign = &single_colored_set.second;
         TF_ASSIGN_OR_RETURN(
-            HeapSimulator::Result result,
+            HeapSimulator::Result<HloValue> result,
             HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                                *instruction_sequence,
                                assignment->alias_analysis(),
@@ -1582,7 +1585,7 @@ std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
 }  // namespace

 void BufferAssigner::AssignBuffersFromHeapSimulator(
-    const HeapSimulator::Result& result, BufferAssignment* assignment,
+    const HeapSimulator::Result<HloValue>& result, BufferAssignment* assignment,
     BufferValue::Color color) {
   if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
     assignment->stats_.preallocated_temp_fragmentation_bytes =

--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -661,9 +661,9 @@ class BufferAssigner {

   // Uses the results of the heap simulator to create a single allocation, with
   // LogicalBuffers packed to specific offsets.
-  void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
-                                      BufferAssignment* assignment,
-                                      LogicalBuffer::Color color);
+  void AssignBuffersFromHeapSimulator(
+      const HeapSimulator::Result<HloValue>& result,
+      BufferAssignment* assignment, LogicalBuffer::Color color);

   // Tries to assign the given instruction to the given buffer. Returns if the
   // assignment was successful.

--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
 #include "tensorflow/compiler/xla/util.h"

 namespace xla {
@@ -55,9 +56,10 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
   // rather than summing each computation, since it gives us a better lower
   // bound, by minimizing the liveness of sub-computations.
   TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
-                         schedule, *alias_analysis, size_function));
+      HeapSimulator::Result<HloValue> result,
+      HeapSimulator::Run(
+          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), *module,
+          schedule, *alias_analysis, size_function));
   return result.heap_size;
 }

@@ -69,10 +71,11 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
     const absl::flat_hash_map<const HloComputation*, int64>*
         memory_by_computation) {
   TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
-                         computation, sequence, alias_analysis, size_function,
-                         HeapSimulator::Options(), memory_by_computation));
+      HeapSimulator::Result<HloValue> result,
+      HeapSimulator::Run(
+          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
+          sequence, alias_analysis, size_function, HeapSimulator::Options(),
+          memory_by_computation));
   return result.heap_size;
 }

@@ -82,16 +85,17 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
     const LogicalBuffer::SizeFunction& size_function,
     const HloSchedule* schedule) {
   TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
-                         computation, sequence, alias_analysis, size_function,
-                         schedule, HeapSimulator::Options()));
+      HeapSimulator::Result<HloValue> result,
+      HeapSimulator::Run(
+          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
+          sequence, alias_analysis, size_function, schedule,
+          HeapSimulator::Options()));
   return result.heap_size;
 }

 /*static*/
-StatusOr<HeapSimulator::Result> HeapSimulator::Run(
-    std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
+StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
+    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm, const HloModule& module,
     const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_fn, const Options& options) {
   HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
@@ -108,8 +112,9 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
 }

 /*static*/
-StatusOr<HeapSimulator::Result> HeapSimulator::Run(
-    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
+StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
+    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
+    const HloComputation& computation,
     const HloInstructionSequence& instruction_sequence,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_fn, const Options& options,
@@ -128,8 +133,9 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
 }

 /*static*/
-StatusOr<HeapSimulator::Result> HeapSimulator::Run(
-    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
+StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
+    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
+    const HloComputation& computation,
     const HloInstructionSequence& instruction_sequence,
     const HloAliasAnalysis& alias_analysis,
     const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
@@ -326,12 +332,13 @@ Status HeapSimulator::RunComputation(
 }

 HeapSimulator::HeapSimulator(
-    std::unique_ptr<HeapAlgorithm> algorithm,
+    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
     const BufferValue::SizeFunction& size_fn, const Options& options,
     const HloSchedule* schedule,
     const absl::flat_hash_map<const HloComputation*, int64>*
         memory_by_computation)
-    : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
+    : no_fragmentation_stats_(
+          absl::make_unique<NoFragmentationStatsHeap<HloValue>>()),
       algorithm_(std::move(algorithm)),
       size_fn_(size_fn),
       options_(options),
@@ -396,8 +403,8 @@ void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                  shared);
 }

-HeapSimulator::Result HeapSimulator::Finish() {
-  Result result = algorithm_->Finish();
+HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
+  Result<HloValue> result = algorithm_->Finish();

   // Post-process the result to add chunks for shared buffers. An empty chunk
   // map means that either no buffers were allocated, or the heap was only
@@ -411,7 +418,7 @@ HeapSimulator::Result HeapSimulator::Finish() {
   }

   // Fragmentation is the difference between the actual and ideal sizes.
-  const Result no_frag_result = no_fragmentation_stats_->Finish();
+  const Result<HloValue> no_frag_result = no_fragmentation_stats_->Finish();
   result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

   // Copy the debug trace we collected to the final result.
@@ -437,14 +444,17 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
   }
 }

-void NoFragmentationStatsHeap::Alloc(const HloValue* buffer, int64 size) {
+template <typename BufferType>
+void NoFragmentationStatsHeap<BufferType>::Alloc(const BufferType* buffer,
+                                                 int64 size) {
   current_heap_size_ += size;
   if (current_heap_size_ > max_heap_size_) {
     max_heap_size_ = current_heap_size_;
   }
 }

-void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
+template <typename BufferType>
+void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
     const HloInstruction* instruction, int64 alloc_size_by_instruction,
     const absl::flat_hash_map<const HloComputation*, int64>&
         memory_by_computation) {
@@ -472,11 +482,15 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
       std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
 }

-void NoFragmentationStatsHeap::Free(const HloValue* buffer, int64 size) {
+template <typename BufferType>
+void NoFragmentationStatsHeap<BufferType>::Free(const BufferType* buffer,
+                                                int64 size) {
   current_heap_size_ -= size;
 }

-HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
+template <typename BufferType>
+HeapSimulator::Result<BufferType>
+NoFragmentationStatsHeap<BufferType>::Finish() {
   // The result.chunk_map is empty, since we only collect stats, and don't
   // actually compute chunk assignments.
   Result result;
@@ -484,7 +498,8 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
   return result;
 }

-GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap(
+template <typename BufferType>
+GlobalDecreasingSizeBestFitHeap<BufferType>::GlobalDecreasingSizeBestFitHeap(
     int64 alignment, Type type)
     : alignment_(alignment) {
   if (type == kTemporal) {
@@ -495,8 +510,10 @@ GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap(
   }
 }

-GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
-GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const {
+template <typename BufferType>
+typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferIntervalCompare
+GlobalDecreasingSizeBestFitHeap<BufferType>::GetTemporalBufferIntervalCompare()
+    const {
   return [&](const BufferInterval& x, const BufferInterval& y) {
     int64 x_end = x.end;
     for (auto colocation : GetTransitiveColocations(x)) {
@@ -515,12 +532,14 @@ GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const {
     if (x.size != y.size) {
       return x.size > y.size;
     }
-    return x.buffer->id() < y.buffer->id();
+    return *x.buffer < *y.buffer;
   };
 }

-/*static*/ GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
-GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() {
+template <typename BufferType>
+/*static*/ typename GlobalDecreasingSizeBestFitHeap<
+    BufferType>::BufferIntervalCompare
+GlobalDecreasingSizeBestFitHeap<BufferType>::GetSpatialBufferIntervalCompare() {
   return [&](const BufferInterval& x, const BufferInterval& y) {
     if (x.size != y.size) {
       return x.size > y.size;
@@ -528,12 +547,13 @@ GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() {
     if (x.end - x.start != y.end - y.start) {
       return x.end - x.start > y.end - y.start;
     }
-    return x.buffer->id() < y.buffer->id();
+    return *x.buffer < *y.buffer;
   };
 }

-void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer,
-                                            int64 size) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
+    const BufferType* buffer, int64 size) {
   // Degenerate case: 0-sized buffers are always allocated at offset 0.
   if (size == 0) {
     result_.chunk_map.emplace(buffer, Chunk{0, 0});
@@ -546,9 +566,9 @@ void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer,
   ++current_time_;
 }

-void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer,
-                                                const HloValue* share_with,
-                                                int64 size) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
+    const BufferType* buffer, const BufferType* share_with, int64 size) {
   // Degenerate case: 0-sized buffers are always allocated at offset 0.
   if (size == 0) {
     result_.chunk_map.emplace(buffer, Chunk{0, 0});
@@ -562,15 +582,16 @@ void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer,
   ++current_time_;
 }

-absl::flat_hash_set<const HloValue*>
-GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations(
+template <typename BufferType>
+absl::flat_hash_set<const BufferType*>
+GlobalDecreasingSizeBestFitHeap<BufferType>::GetTransitiveColocations(
     const BufferInterval& interval) const {
-  absl::flat_hash_set<const HloValue*> result;
+  absl::flat_hash_set<const BufferType*> result;
   std::vector<const BufferInterval*> worklist = {&interval};
   while (!worklist.empty()) {
     const BufferInterval* item = worklist.back();
     worklist.pop_back();
-    for (const HloValue* buffer_colocated : item->colocations) {
+    for (const BufferType* buffer_colocated : item->colocations) {
       result.insert(buffer_colocated);
       worklist.push_back(&buffer_intervals_.at(buffer_colocated));
     }
@@ -579,7 +600,9 @@ GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations(
   return result;
 }

-void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
+                                                       int64 size) {
   // Degenerate case: 0-sized buffers are always allocated at offset 0.
   if (size == 0) {
     return;
@@ -785,7 +808,9 @@ std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
   return result;
 }

-HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
+template <typename BufferType>
+HeapSimulator::Result<BufferType>
+GlobalDecreasingSizeBestFitHeap<BufferType>::Finish() {
   std::vector<BufferInterval> sorted_buffer_intervals =
       GetSortedBufferIntervals();

@@ -803,8 +828,10 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
   return result_;
 }

-std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval>
-GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const {
+template <typename BufferType>
+std::vector<
+    typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval>
+GlobalDecreasingSizeBestFitHeap<BufferType>::GetSortedBufferIntervals() const {
   std::vector<BufferInterval> sorted_buffer_intervals;
   for (auto& entry : buffer_intervals_) {
     sorted_buffer_intervals.push_back(entry.second);
@@ -814,8 +841,9 @@ GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const {
   return sorted_buffer_intervals;
 }

-GlobalDecreasingSizeBestFitHeap::ChunkCandidate
-GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
+template <typename BufferType>
+typename GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
+GlobalDecreasingSizeBestFitHeap<BufferType>::FindChunkCandidate(
     const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
     int64 preferred_offset) const {
   VLOG(1) << "Finding chunks for buffer: "
@@ -912,9 +940,12 @@ GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
   return chunk_candidate;
 }

-void GlobalDecreasingSizeBestFitHeap::CommitChunk(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
-    GlobalDecreasingSizeBestFitHeap::ChunkCandidate chunk_candidate) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::CommitChunk(
+    const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval&
+        buffer_interval,
+    GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
+        chunk_candidate) {
   // Update the maximum heap size according to the one determined by the chunk
   // candidate.
   result_.heap_size = chunk_candidate.heap_size;
@@ -930,13 +961,16 @@ void GlobalDecreasingSizeBestFitHeap::CommitChunk(
   AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk);
 }

-void GlobalDecreasingSizeBestFitHeap::AddToChunkMap(const HloValue* buffer,
-                                                    Chunk chunk) {
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::AddToChunkMap(
+    const BufferType* buffer, Chunk chunk) {
   const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
   DCHECK(emplace_result.second);
 }

-HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
+template <typename BufferType>
+HeapSimulator::Result<BufferType>
+ChooseBestHeapAlgorithm<BufferType>::Finish() {
   DCHECK(!algorithms_.empty());
   std::vector<Result> results(algorithms_.size());
   int64 min_size = INT64_MAX;
@@ -953,4 +987,9 @@ HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
   return results[min_size_index];
 }

+template class GlobalDecreasingSizeBestFitHeap<HloValue>;
+template class GlobalDecreasingSizeBestFitHeap<
+    MemorySpaceAssignmentRepacker::AllocationBlock>;
+template class ChooseBestHeapAlgorithm<HloValue>;
+
 }  // namespace xla
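
Two details of the heap_simulator.cc changes above are worth calling out.
First, the tie-breaking comparators switch from "x.buffer->id() <
y.buffer->id()" to "*x.buffer < *y.buffer": an arbitrary BufferType cannot be
assumed to have HloValue's id(), so buffer types used with the best-fit heap
are instead expected to be ordered by operator<. Second, the template
definitions stay in the .cc file, so the explicit instantiations added at the
bottom ("template class GlobalDecreasingSizeBestFitHeap<HloValue>;" and
friends) are what force the compiler to emit code for the supported buffer
types. A self-contained illustration of that pattern, with hypothetical names
unrelated to XLA:

    // packer.h -- declaration only; the definition lives in packer.cc.
    template <typename T>
    class Packer {
     public:
      void Add(const T* item);
    };

    // packer.cc -- definition plus explicit instantiations.
    #include "packer.h"

    template <typename T>
    void Packer<T>::Add(const T* item) {
      // ... packing logic ...
    }

    // Without the next two lines, a caller of Packer<int>::Add() linking
    // against packer.o would fail with an undefined symbol, because the
    // compiler never emitted that specialization while compiling packer.cc.
    template class Packer<int>;
    template class Packer<float>;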

--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -40,7 +40,9 @@ limitations under the License.
 namespace xla {

 // Forward declare classes defined below.
+template <typename BufferType>
 class HeapAlgorithm;
+template <typename BufferType>
 class NoFragmentationStatsHeap;

 // HeapSimulator assigns buffer offsets by running a simulation of a regular
@@ -66,9 +68,10 @@ class HeapSimulator {
   };

   // Result represents the result of the heap simulation.
+  template <typename BufferType>
   struct Result {
     // The assignment of buffers to chunks.
-    absl::flat_hash_map<const HloValue*, Chunk> chunk_map;
+    absl::flat_hash_map<const BufferType*, Chunk> chunk_map;

     // The total size in bytes of the heap, containing all assigned chunks.
     int64 heap_size = 0;
@@ -128,19 +131,19 @@ class HeapSimulator {
   // to running on a per-computation basis, since we can re-use buffer space for
   // called sub-computations.
   //
-  static StatusOr<Result> Run(std::unique_ptr<HeapAlgorithm> algorithm,
-                              const HloModule& module,
-                              const HloSchedule& schedule,
+  static StatusOr<Result<HloValue>> Run(
+      std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
+      const HloModule& module, const HloSchedule& schedule,
      const HloAliasAnalysis& alias_analysis,
      const BufferValue::SizeFunction& size_fn,
      const Options& options = Options());

   // Same as above, but runs on a single computation. The 'instruction_sequence'
   // must contain a topologically-consistent total ordering of all instructions
   // in the computation. The result is invalid if instructions are not run in
   // exactly this sequence.
-  static StatusOr<Result> Run(
-      std::unique_ptr<HeapAlgorithm> algorithm,
+  static StatusOr<Result<HloValue>> Run(
+      std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
       const HloComputation& computation,
       const HloInstructionSequence& instruction_sequence,
       const HloAliasAnalysis& alias_analysis,
@@ -151,8 +154,8 @@ class HeapSimulator {

   // Same as above, but runs on with a schedule that covers all nested
   // computations.
-  static StatusOr<Result> Run(
-      std::unique_ptr<HeapAlgorithm> algorithm,
+  static StatusOr<Result<HloValue>> Run(
+      std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
       const HloComputation& computation,
       const HloInstructionSequence& instruction_sequence,
       const HloAliasAnalysis& alias_analysis,
@@ -163,7 +166,7 @@ class HeapSimulator {
   // If 'schedule' is non-null, it is used to find kCall and kWhile
   // sub-computations, and the heap simulation for those sub-computations will
   // be run recursively. I.e. the simulation is run over the whole module.
-  HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
+  HeapSimulator(std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
                 const BufferValue::SizeFunction& size_fn,
                 const Options& options, const HloSchedule* schedule = nullptr,
                 const absl::flat_hash_map<const HloComputation*, int64>*
@@ -187,7 +190,7 @@ class HeapSimulator {
   // Two buffers belong to the same shared group.
   // Eight of the buffer has no shared group assigned.
   bool InSameSharedGroup(const HloValue* left, const HloValue* right);
-  Result Finish();
+  Result<HloValue> Finish();

   void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                       const HloValue* buffer, const HloInstruction* instruction,
@@ -196,8 +199,9 @@ class HeapSimulator {
   // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap,
   // in which case we are calculating the same allocs/frees twice in the
   // simulation.
-  const std::unique_ptr<NoFragmentationStatsHeap> no_fragmentation_stats_;
-  const std::unique_ptr<HeapAlgorithm> algorithm_;
+  const std::unique_ptr<NoFragmentationStatsHeap<HloValue>>
+      no_fragmentation_stats_;
+  const std::unique_ptr<HeapAlgorithm<HloValue>> algorithm_;
   const BufferValue::SizeFunction size_fn_;
   const Options options_;
   // schedule_ is set by buffer assignment, and memory_by_computation_ is
@@ -220,15 +224,16 @@ class HeapSimulator {
 // offsets to buffers. A sequence of Alloc / Free calls will be made, with the
 // same semantics as a regular memory heap. Finish will be called at the end to
 // collect the simulation results.
+template <typename BufferType>
 class HeapAlgorithm {
  public:
   using Chunk = HeapSimulator::Chunk;
-  using Result = HeapSimulator::Result;
+  using Result = HeapSimulator::Result<BufferType>;

   virtual ~HeapAlgorithm() = default;

   // Alloc allocates a buffer of 'size' bytes.
-  virtual void Alloc(const HloValue* buffer, int64 size) = 0;
+  virtual void Alloc(const BufferType* buffer, int64 size) = 0;

   // Takes memory usage of subcomputations into account when calculating the
   // memory usage of a computation. Currently, we don't handle buffer aliasing
@@ -247,7 +252,7 @@ class HeapAlgorithm {
           memory_by_computation) {}

   // Free de-allocates a previously allocated buffer.
-  virtual void Free(const HloValue* buffer, int64 size) = 0;
+  virtual void Free(const BufferType* buffer, int64 size) = 0;

   // Indicates that a buffer has to be collocated with another buffer. In
   // addition to Alloc and Free, the heap simulator exposes a concept of buffer
@@ -255,7 +260,7 @@ class HeapAlgorithm {
   // the buffer, it associates the buffer with a previously allocated (or
   // shared) buffer. Each group of mutually-shared buffers points to a single
   // SharedGroup instance, which is a shared control block.
-  virtual void ShareWith(const HloValue* buffer, const HloValue* share_with,
+  virtual void ShareWith(const BufferType* buffer, const BufferType* share_with,
                          int64 size) {
     Alloc(buffer, size);
   }
@@ -269,19 +274,22 @@ class HeapAlgorithm {
 // this is the absolute minimum size for a given instruction sequence. The
 // result.chunk_map returned in Finish is always empty, since we only collect
 // stats, and don't actually compute chunk assignments.
-class NoFragmentationStatsHeap : public HeapAlgorithm {
+template <typename BufferType>
+class NoFragmentationStatsHeap : public HeapAlgorithm<BufferType> {
  public:
+  using Result = HeapSimulator::Result<BufferType>;
+
   NoFragmentationStatsHeap() = default;
   ~NoFragmentationStatsHeap() override = default;

-  void Alloc(const HloValue* buffer, int64 size) override;
+  void Alloc(const BufferType* buffer, int64 size) override;

   void AccountForSubcomputationMemory(
       const HloInstruction* instruction, int64 alloc_size_by_instruction,
       const absl::flat_hash_map<const HloComputation*, int64>&
           memory_by_computation) override;

-  void Free(const HloValue* buffer, int64 size) override;
+  void Free(const BufferType* buffer, int64 size) override;

   Result Finish() override;

@@ -336,8 +344,12 @@ class BufferIntervalTree {
 // alloc/free time. It internally tracks the allocated buffers and their live
 // intervals; when allocating a buffer, it finds the best-fit free chunk during
 // its live interval.
-class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
+template <typename BufferType>
+class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm<BufferType> {
  public:
+  using Result = HeapSimulator::Result<BufferType>;
+  using Chunk = HeapSimulator::Chunk;
+
   enum Type {
     kSpatial = 0,
     kTemporal,
@@ -345,7 +357,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {

   // BufferInterval stores a buffer's size and time interval.
   struct BufferInterval {
-    const HloValue* buffer;
+    const BufferType* buffer;
     int64 size;
     // Alloc time of the buffer.
     int64 start;
@@ -353,7 +365,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
     int64 end;

     // Colocation buffers that need to be collocated with this one.
-    std::vector<const HloValue*> colocations;
+    std::vector<const BufferType*> colocations;

     // True if this buffer needs an allocation. False if it is collocated with
     // other buffer.
@@ -368,10 +380,10 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
                                   Type type = kSpatial);
   ~GlobalDecreasingSizeBestFitHeap() override {}

-  void Alloc(const HloValue* buffer, int64 size) override;
-  void Free(const HloValue* buffer, int64 size) override;
+  void Alloc(const BufferType* buffer, int64 size) override;
+  void Free(const BufferType* buffer, int64 size) override;

-  void ShareWith(const HloValue* buffer, const HloValue* share_with,
+  void ShareWith(const BufferType* buffer, const BufferType* share_with,
                  int64 size) override;

   Result Finish() override;
@@ -404,7 +416,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
   void CommitChunk(const BufferInterval& buffer_interval,
                    ChunkCandidate chunk_candidate);
   // Adds the buffer and the chunk to the result chunk map.
-  virtual void AddToChunkMap(const HloValue* buffer, Chunk chunk);
+  virtual void AddToChunkMap(const BufferType* buffer, Chunk chunk);

   // Return a BufferIntervalCompare function that sorts by live ranges. A live
   // range is defined by the range between the start of the first buffer and the
@@ -413,7 +425,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
   // contiguous.
   BufferIntervalCompare GetTemporalBufferIntervalCompare() const;

-  absl::flat_hash_map<const HloValue*, BufferInterval> buffer_intervals_;
+  absl::flat_hash_map<const BufferType*, BufferInterval> buffer_intervals_;
   Result result_;
   BufferIntervalCompare buffer_interval_compare_;
   BufferIntervalTree interval_tree_;
@@ -428,33 +440,37 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
   // Returns all transitive colocated buffers of this buffer interval. I.e., If
   // a buffer A is colocated with B and B is colocated with C, this function
   // returns all three of them.
-  absl::flat_hash_set<const HloValue*> GetTransitiveColocations(
+  absl::flat_hash_set<const BufferType*> GetTransitiveColocations(
       const BufferInterval& interval) const;
 };

 // A heap algorithm that chooses the best results from other algorithms added to
 // it.
-class ChooseBestHeapAlgorithm : public HeapAlgorithm {
+template <typename BufferType>
+class ChooseBestHeapAlgorithm : public HeapAlgorithm<BufferType> {
  public:
+  using Result = HeapSimulator::Result<BufferType>;
+
   ChooseBestHeapAlgorithm(
-      std::unique_ptr<std::vector<std::unique_ptr<HeapAlgorithm>>> algorithms)
+      std::unique_ptr<std::vector<std::unique_ptr<HeapAlgorithm<BufferType>>>>
+          algorithms)
       : algorithms_(std::move(*algorithms)) {}
   ~ChooseBestHeapAlgorithm() override {}

-  void Alloc(const HloValue* buffer, int64 size) override {
+  void Alloc(const BufferType* buffer, int64 size) override {
     for (auto& algorithm : algorithms_) {
       algorithm->Alloc(buffer, size);
     }
   }

-  void ShareWith(const HloValue* buffer, const HloValue* share_with,
+  void ShareWith(const BufferType* buffer, const BufferType* share_with,
                  int64 size) override {
     for (auto& algorithm : algorithms_) {
       algorithm->ShareWith(buffer, share_with, size);
     }
   }

-  void Free(const HloValue* buffer, int64 size) override {
+  void Free(const BufferType* buffer, int64 size) override {
     for (auto& algorithm : algorithms_) {
       algorithm->Free(buffer, size);
     }
@@ -463,7 +479,7 @@ class ChooseBestHeapAlgorithm : public HeapAlgorithm {
   Result Finish() override;

  private:
-  std::vector<std::unique_ptr<HeapAlgorithm>> algorithms_;
+  std::vector<std::unique_ptr<HeapAlgorithm<BufferType>>> algorithms_;
 };

 }  // namespace xla
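
With the header in this shape, the buffer type is purely the caller's choice
and the packing logic itself is untouched. A hedged sketch of the repacker
direction this commit prepares for; only the AllocationBlock type name comes
from this commit (via memory_space_assignment_repacking.h), and its fields and
this driver loop are illustrative assumptions:

    // Hypothetical driver, not part of this commit. Assumes AllocationBlock
    // has a size field and defines operator<, which the new tie-breaking
    // comparator relies on.
    using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock;

    HeapSimulator::Result<AllocationBlock> RepackBlocks(
        const std::vector<AllocationBlock*>& blocks) {
      GlobalDecreasingSizeBestFitHeap<AllocationBlock> heap(/*alignment=*/1);
      // Toy lifetime replay: everything is allocated up front and freed at
      // the end, so all live ranges overlap. A real repacker would replay
      // each block's actual alloc/free times.
      for (const AllocationBlock* block : blocks) {
        heap.Alloc(block, block->size);
      }
      for (const AllocationBlock* block : blocks) {
        heap.Free(block, block->size);
      }
      return heap.Finish();  // chunk_map is now keyed by AllocationBlock*.
    }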
@ -228,7 +228,7 @@ const char kFinish[] = "Finish";
|
|||||||
using CallSequence = std::vector<std::pair<string, const HloValue*>>;
|
using CallSequence = std::vector<std::pair<string, const HloValue*>>;
|
||||||
|
|
||||||
// HeapCallRecorder is a dummy heap algorithm that simply records its calls.
|
// HeapCallRecorder is a dummy heap algorithm that simply records its calls.
|
||||||
class HeapCallRecorder : public HeapAlgorithm {
|
class HeapCallRecorder : public HeapAlgorithm<HloValue> {
|
||||||
public:
|
public:
|
||||||
explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
|
explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
|
||||||
~HeapCallRecorder() override {}
|
~HeapCallRecorder() override {}
|
||||||
@ -396,7 +396,7 @@ class HeapSimulatorTracker {
|
|||||||
std::unique_ptr<HloModule> module_;
|
std::unique_ptr<HloModule> module_;
|
||||||
std::unique_ptr<HloAliasAnalysis> alias_analysis_;
|
std::unique_ptr<HloAliasAnalysis> alias_analysis_;
|
||||||
CallSequence actual_calls_;
|
CallSequence actual_calls_;
|
||||||
HeapSimulator::Result result_;
|
HeapSimulator::Result<HloValue> result_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class HeapSimulatorTest : public HloTestBase {
|
class HeapSimulatorTest : public HloTestBase {
|
||||||
@ -976,12 +976,12 @@ class HeapAlgorithmTestBase : public ::testing::Test {
|
|||||||
class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
|
class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
|
||||||
|
|
||||||
TEST_F(NoFragmentationStatsHeapTest, Empty) {
|
TEST_F(NoFragmentationStatsHeapTest, Empty) {
|
||||||
NoFragmentationStatsHeap heap;
|
NoFragmentationStatsHeap<HloValue> heap;
|
||||||
EXPECT_EQ(0, heap.Finish().heap_size);
|
EXPECT_EQ(0, heap.Finish().heap_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoFragmentationStatsHeapTest, Simple) {
|
TEST_F(NoFragmentationStatsHeapTest, Simple) {
|
||||||
NoFragmentationStatsHeap heap;
|
NoFragmentationStatsHeap<HloValue> heap;
|
||||||
heap.Alloc(buffer_a_, 10);
|
heap.Alloc(buffer_a_, 10);
|
||||||
heap.Alloc(buffer_b_, 20);
|
heap.Alloc(buffer_b_, 20);
|
||||||
heap.Alloc(buffer_c_, 30);
|
heap.Alloc(buffer_c_, 30);
|
||||||
@ -994,7 +994,7 @@ TEST_F(NoFragmentationStatsHeapTest, Simple) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoFragmentationStatsHeapTest, Mixed) {
|
TEST_F(NoFragmentationStatsHeapTest, Mixed) {
|
||||||
NoFragmentationStatsHeap heap;
|
NoFragmentationStatsHeap<HloValue> heap;
|
||||||
heap.Alloc(buffer_a_, 10); // max: A
|
heap.Alloc(buffer_a_, 10); // max: A
|
||||||
|
|
||||||
heap.Alloc(buffer_b_, 20); // max: A+B
|
heap.Alloc(buffer_b_, 20); // max: A+B
|
||||||
@ -1013,7 +1013,7 @@ TEST_F(NoFragmentationStatsHeapTest, Mixed) {
|
|||||||
class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
|
class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
|
||||||
protected:
|
protected:
|
||||||
class InheritedGlobalDecreasingSizeBestFitHeap
|
class InheritedGlobalDecreasingSizeBestFitHeap
|
||||||
: public GlobalDecreasingSizeBestFitHeap {
|
: public GlobalDecreasingSizeBestFitHeap<HloValue> {
|
||||||
public:
|
public:
|
||||||
InheritedGlobalDecreasingSizeBestFitHeap()
|
InheritedGlobalDecreasingSizeBestFitHeap()
|
||||||
: GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {}
|
: GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {}
|
||||||
@ -1048,8 +1048,8 @@ class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
|
TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
|
||||||
GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
|
GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
|
||||||
const HeapSimulator::Result result = heap.Finish();
|
const HeapSimulator::Result<HloValue> result = heap.Finish();
|
||||||
EXPECT_EQ(0, result.heap_size);
|
EXPECT_EQ(0, result.heap_size);
|
||||||
EXPECT_EQ(0, result.chunk_map.size());
|
EXPECT_EQ(0, result.chunk_map.size());
|
||||||
}
|
}
|
||||||
@ -1068,7 +1068,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
|
|||||||
// | | d |
|
// | | d |
|
||||||
// | +-------+
|
// | +-------+
|
||||||
// -----------------> time
|
// -----------------> time
|
||||||
GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
|
GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
|
||||||
heap.Alloc(buffer_a_, 10);
|
heap.Alloc(buffer_a_, 10);
|
||||||
heap.Alloc(buffer_b_, 30);
|
heap.Alloc(buffer_b_, 30);
|
||||||
heap.Alloc(buffer_c_, 20);
|
heap.Alloc(buffer_c_, 20);
|
||||||
@ -1078,7 +1078,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
|
|||||||
heap.Free(buffer_c_, 20);
|
heap.Free(buffer_c_, 20);
|
||||||
heap.Free(buffer_d_, 40);
|
heap.Free(buffer_d_, 40);
|
||||||
|
|
||||||
const HeapSimulator::Result result = heap.Finish();
|
const HeapSimulator::Result<HloValue> result = heap.Finish();
|
||||||
EXPECT_EQ(100, result.heap_size);
|
EXPECT_EQ(100, result.heap_size);
|
||||||
EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
|
EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
|
||||||
EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
|
EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
|
||||||
@ -1107,7 +1107,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
|
|||||||
// | | |
|
// | | |
|
||||||
// | +-------+
|
// | +-------+
|
||||||
// ---------------------> time
|
// ---------------------> time
|
||||||
GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
|
GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/20);
|
||||||
heap.Alloc(buffer_a_, 10);
|
heap.Alloc(buffer_a_, 10);
|
||||||
heap.Alloc(buffer_b_, 20);
|
heap.Alloc(buffer_b_, 20);
|
||||||
heap.Alloc(buffer_c_, 50);
|
heap.Alloc(buffer_c_, 50);
|
||||||
@ -1117,7 +1117,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
|
|||||||
heap.Free(buffer_c_, 50);
|
heap.Free(buffer_c_, 50);
|
||||||
heap.Free(buffer_d_, 40);
|
heap.Free(buffer_d_, 40);
|
||||||
|
|
||||||
const HeapSimulator::Result result = heap.Finish();
|
const HeapSimulator::Result<HloValue> result = heap.Finish();
|
||||||
EXPECT_EQ(120, result.heap_size);
|
EXPECT_EQ(120, result.heap_size);
|
||||||
EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
|
EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
|
||||||
EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
|
EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
|
||||||
@ -1148,7 +1148,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
|
|||||||
// | | |
|
// | | |
|
||||||
// | +-------+
|
// | +-------+
|
||||||
// ---------------------> time
|
// ---------------------> time
|
||||||
GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
|
GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
|
||||||
heap.Alloc(buffer_a_, 10);
|
heap.Alloc(buffer_a_, 10);
|
||||||
heap.Alloc(buffer_b_, 20);
|
heap.Alloc(buffer_b_, 20);
|
||||||
heap.Alloc(buffer_c_, 40);
|
heap.Alloc(buffer_c_, 40);
|
||||||
@ -1160,7 +1160,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
|
|||||||
heap.Free(buffer_d_, 30);
|
heap.Free(buffer_d_, 30);
|
||||||
heap.Free(buffer_e_, 50);
|
heap.Free(buffer_e_, 50);
|
||||||
|
|
||||||
const HeapSimulator::Result result = heap.Finish();
|
const HeapSimulator::Result<HloValue> result = heap.Finish();
|
||||||
EXPECT_EQ(140, result.heap_size);
|
EXPECT_EQ(140, result.heap_size);
|
||||||
EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
|
EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
|
||||||
EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
|
EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
|
||||||
@ -1184,7 +1184,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) {
|
|||||||
// || |+----+| |
|
// || |+----+| |
|
||||||
// |+--a---++-b--++---c---+
|
// |+--a---++-b--++---c---+
|
||||||
// ---------------------> time
|
// ---------------------> time
|
||||||
GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
|
GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
|
||||||
heap.Alloc(buffer_a_, 40);
|
heap.Alloc(buffer_a_, 40);
|
||||||
heap.Free(buffer_a_, 40);
|
heap.Free(buffer_a_, 40);
|
||||||
heap.Alloc(buffer_b_, 20);
|
heap.Alloc(buffer_b_, 20);
|
||||||
@ -1192,7 +1192,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) {
|
|||||||
heap.ShareWith(buffer_c_, buffer_a_, 40);
|
heap.ShareWith(buffer_c_, buffer_a_, 40);
|
||||||
heap.Free(buffer_c_, 40);
|
heap.Free(buffer_c_, 40);
|
||||||
|
|
||||||
const HeapSimulator::Result result = heap.Finish();
|
const HeapSimulator::Result<HloValue> result = heap.Finish();
|
||||||
EXPECT_EQ(40, result.heap_size);
|
EXPECT_EQ(40, result.heap_size);
|
||||||
EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
|
EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
|
||||||
EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
|
EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
|
||||||
@ -1212,7 +1212,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
|
|||||||
// || | | | <--- colocate with a
|
// || | | | <--- colocate with a
|
||||||
// |+--a---+ +---c---+
|
// |+--a---+ +---c---+
|
||||||
// ---------------------> time
|
// ---------------------> time
|
||||||
GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
|
GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
|
||||||
heap.Alloc(buffer_a_, 40);
|
heap.Alloc(buffer_a_, 40);
|
||||||
heap.Free(buffer_a_, 40);
|
heap.Free(buffer_a_, 40);
|
||||||
heap.Alloc(buffer_b_, 20);
|
heap.Alloc(buffer_b_, 20);
|
||||||
@ -1221,7 +1221,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
|
|||||||
heap.Free(buffer_c_, 40);
|
heap.Free(buffer_c_, 40);
|
||||||
heap.Free(buffer_b_, 20);
|
heap.Free(buffer_b_, 20);
|
||||||
|
|
||||||
const HeapSimulator::Result result = heap.Finish();
|
const HeapSimulator::Result<HloValue> result = heap.Finish();
|
||||||
EXPECT_EQ(60, result.heap_size);
|
EXPECT_EQ(60, result.heap_size);
|
||||||
EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
|
EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
|
||||||
EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
|
EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
|
||||||
@ -1242,7 +1242,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) {
|
|||||||
// | | |
|
// | | |
|
||||||
// | +-------b-------+
|
// | +-------b-------+
|
||||||
// ---------------------> time
|
// ---------------------> time
|
||||||
GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
|
GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
|
||||||
heap.Alloc(buffer_a_, 10);
|
heap.Alloc(buffer_a_, 10);
|
||||||
heap.Free(buffer_a_, 10);
|
heap.Free(buffer_a_, 10);
|
||||||
heap.Alloc(buffer_b_, 30);
|
heap.Alloc(buffer_b_, 30);
|
||||||
@@ -1251,7 +1251,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) {
   heap.Free(buffer_c_, 10);
   heap.Free(buffer_b_, 30);

-  const HeapSimulator::Result result = heap.Finish();
+  const HeapSimulator::Result<HloValue> result = heap.Finish();
   EXPECT_EQ(40, result.heap_size);
   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
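These tests pin down the default HloValue instantiation. Because the heap algorithms are now templated on the buffer type, the same Alloc/Free/ShareWith/Finish machinery works for any opaque key. A minimal sketch, assuming a hypothetical Buffer struct (not part of this change) as the template argument:

  struct Buffer {};  // hypothetical opaque type; the heap keys on pointers
  Buffer a, b;
  GlobalDecreasingSizeBestFitHeap<Buffer> heap(/*alignment=*/1);
  heap.Alloc(&a, 40);
  heap.Free(&a, 40);
  heap.Alloc(&b, 20);
  heap.Free(&b, 20);
  const HeapSimulator::Result<Buffer> result = heap.Finish();
  // result.heap_size is the peak; result.chunk_map is keyed by const Buffer*,
  // mirroring the const HloValue* keys in the tests above.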
@@ -50,9 +50,9 @@ int64 PeakMemoryUseOfEntryComputation(

   HloComputation* computation = module->entry_computation();
   const HloInstructionSequence& sequence = schedule.sequence(computation);
-  return HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
-                            *computation, sequence, *alias_analysis,
-                            size_function)
+  return HeapSimulator::Run(
+             absl::make_unique<NoFragmentationStatsHeap<HloValue>>(),
+             *computation, sequence, *alias_analysis, size_function)
       .ValueOrDie()
       .heap_size;
 }
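Only the template argument changes here; NoFragmentationStatsHeap<HloValue> still just accumulates live-byte statistics. As a sketch of that behavior (assuming, as the .heap_size read-out above suggests, that it tracks the running peak and assigns no chunks; value_a and value_b stand in for const HloValue* obtained from alias analysis):

  NoFragmentationStatsHeap<HloValue> stats;
  stats.Alloc(value_a, 16);  // live bytes: 16
  stats.Alloc(value_b, 48);  // live bytes: 64  <- peak
  stats.Free(value_a, 16);   // live bytes: 48
  stats.Free(value_b, 48);   // live bytes: 0
  // stats.Finish().heap_size == 64; no offsets were computed.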
@@ -80,7 +80,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
 }

 float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
+    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
     MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
   const HloInstruction& defining_instruction =
       *interval.buffer->defining_instruction();
@@ -570,7 +570,8 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(

 absl::optional<float>
 CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
+    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
+    const {
   return cost_analysis_.GetMemoryBoundedness(interval);
 }
@@ -733,9 +734,9 @@ void AlternateMemoryBestFitHeap::FindAliases(
   }
 }

-std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
+std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>
 AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
+    const AlternateMemoryBestFitHeap::BufferInterval& interval) const {
   std::vector<const BufferInterval*> colocated_intervals;
   std::vector<const BufferInterval*> worklist = {&interval};
   while (!worklist.empty()) {
@@ -864,7 +865,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
 }

 void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
+    const AlternateMemoryBestFitHeap::BufferInterval& interval,
     std::string* debug_str) const {
   // Columns in buffer information:
   // buffer_id: int. This value can be used to match the allocation in
@@ -954,7 +955,7 @@ void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
   options_.dump_fn("allocinfo", allocation_info_str_);
 }

-HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
+HeapSimulator::Result<HloValue> AlternateMemoryBestFitHeap::Finish() {
   std::vector<BufferInterval> sorted_buffer_intervals =
       GetSortedBufferIntervals();
@@ -1390,10 +1391,10 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
   MemorySpaceAssignment::Allocation* last_allocation =
       allocations_->at(1).get();
   CHECK(last_allocation->memory_space() == MemorySpace::kAlternate);
-  repack_allocation_blocks_.push_back(RepackAllocationBlock(
+  repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
       last_allocation->start_time(), last_allocation->end_time(),
       last_allocation->chunk().size, last_allocation->chunk().offset,
-      last_allocation));
+      static_cast<int64>(repack_allocation_blocks_.size()), last_allocation));
   repack_allocation_blocks_.back().colocations.push_back(
       &repack_allocation_blocks_.back());
@@ -1671,10 +1672,12 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations(
     std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
     for (MemorySpaceAssignment::Allocation* colocated_allocation :
          colocation.second) {
-      repack_allocation_blocks_.push_back(RepackAllocationBlock(
+      repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
           colocated_allocation->start_time(), colocated_allocation->end_time(),
           colocated_allocation->chunk().size,
-          colocated_allocation->chunk().offset, colocated_allocation));
+          colocated_allocation->chunk().offset,
+          static_cast<int64>(repack_allocation_blocks_.size()),
+          colocated_allocation));
       colocations.push_back(&repack_allocation_blocks_.back());
     }
     for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
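Both call sites derive the new id argument from the current size of repack_allocation_blocks_, so ids are assigned 0, 1, 2, ... in insertion order and are therefore unique. A standalone sketch of the invariant (the local container and null allocation are placeholders):

  std::vector<RepackAllocationBlock> blocks;
  for (int i = 0; i < 3; ++i) {
    blocks.push_back(MakeRepackAllocationBlock(
        /*start_time=*/i, /*end_time=*/i + 1, /*size=*/8,
        /*initial_offset=*/0,
        /*id=*/static_cast<int64>(blocks.size()), /*allocation=*/nullptr));
  }
  // blocks[k].id == k for every k, independent of any hash-map iteration
  // order upstream.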
@@ -2369,8 +2372,8 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
       return x_memory_boundedness > y_memory_boundedness;
     }
     // Tie-break if the memory boundedness is the same.
-    return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()(
-        x, y);
+    return GlobalDecreasingSizeBestFitHeap<
+        HloValue>::GetSpatialBufferIntervalCompare()(x, y);
   };
 }
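The comparator here composes a primary key with a deterministic tie-breaker. A minimal sketch of the same pattern (ScoreOf is a hypothetical stand-in for the memory-boundedness computation):

  using BufferInterval =
      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval;
  auto compare = [](const BufferInterval& x, const BufferInterval& y) {
    const float sx = ScoreOf(x);  // hypothetical primary key
    const float sy = ScoreOf(y);
    if (sx != sy) return sx > sy;  // larger score sorts first
    // Fall back to the heap's own spatial compare so ties are deterministic.
    return GlobalDecreasingSizeBestFitHeap<
        HloValue>::GetSpatialBufferIntervalCompare()(x, y);
  };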
@@ -106,7 +106,7 @@ class MemorySpaceAssignmentCostAnalysis {
   // BufferInterval. The larger this number, the higher priority it will be
   // placed in the alternate memory.
   float GetMemoryBoundedness(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
+      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
       Cache* cache = nullptr) const;

   // Returns the elapsed time in seconds due to compute only.
@@ -235,7 +235,8 @@ class PrefetchIntervalPicker {
   // of placing the BufferInterval in the alternate memory. The larger value,
   // the more beneficial.
   virtual absl::optional<float> BufferIntervalAlternateMemoryBenefit(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
+      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
+      const {
     return absl::nullopt;
   }
@@ -324,7 +325,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
       int64 end_time) const override;

   absl::optional<float> BufferIntervalAlternateMemoryBenefit(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval)
+      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
       const override;

  private:
@@ -370,9 +371,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
 class MemorySpaceAssignment {
  public:
   using Chunk = HeapSimulator::Chunk;
-  using BufferInterval = GlobalDecreasingSizeBestFitHeap::BufferInterval;
+  using BufferInterval =
+      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval;
   using BufferIntervalCompare =
-      GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare;
+      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare;
   using IsAllowedInAlternateMemoryFunction =
       std::function<bool(const HloValue&)>;
@@ -913,7 +915,8 @@ class AsynchronousCopyOrdering {

 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
 // maximum size.
-class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
+class AlternateMemoryBestFitHeap
+    : public GlobalDecreasingSizeBestFitHeap<HloValue> {
  public:
   using MemorySpace = MemorySpaceAssignment::MemorySpace;
   using AllocationValue = MemorySpaceAssignment::AllocationValue;
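With the base now templated, a subclass commits to its buffer type once and inherits the interval machinery for that type. A hypothetical sketch (not in this change) of a heap keyed on the repacker's AllocationBlock, the kind of reuse this templating permits:

  class RepackingHeap
      : public GlobalDecreasingSizeBestFitHeap<
            MemorySpaceAssignmentRepacker::AllocationBlock> {
   public:
    explicit RepackingHeap(int64 alignment)
        : GlobalDecreasingSizeBestFitHeap<
              MemorySpaceAssignmentRepacker::AllocationBlock>(alignment) {}
    // Finish() would return HeapSimulator::Result<AllocationBlock>.
  };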
@@ -940,25 +943,13 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   void AllocateCrossProgramPrefetchBuffer(
       HloModule* module, absl::optional<BufferInterval> prefetch_candidate);

-  HeapSimulator::Result Finish() override;
+  HeapSimulator::Result<HloValue> Finish() override;

  private:
   // We inherit AllocationBlock struct to attach the Allocation information to
   // make importing repacked offsets easier.
   struct RepackAllocationBlock
       : MemorySpaceAssignmentRepacker::AllocationBlock {
-    RepackAllocationBlock(int64 start_time, int64 end_time, int64 size,
-                          int64 initial_offset,
-                          MemorySpaceAssignment::Allocation* allocation) {
-      this->start_time = start_time;
-      this->end_time = end_time;
-      this->size = size;
-      this->offset = -1;
-      this->initial_offset = initial_offset;
-      this->colocations = {};
-      this->allocation = allocation;
-    }
-
     MemorySpaceAssignment::Allocation* allocation;
   };
@@ -1231,6 +1222,22 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
     return options_.max_size_in_bytes - reserved_in_bytes_;
   }

+  // Creates and returns a RepackAllocationBlock.
+  static RepackAllocationBlock MakeRepackAllocationBlock(
+      int64 start_time, int64 end_time, int64 size, int64 initial_offset,
+      int64 id, MemorySpaceAssignment::Allocation* allocation) {
+    RepackAllocationBlock allocation_block;
+    allocation_block.start_time = start_time;
+    allocation_block.end_time = end_time;
+    allocation_block.size = size;
+    allocation_block.offset = -1;
+    allocation_block.initial_offset = initial_offset;
+    allocation_block.id = id;
+    allocation_block.colocations = {};
+    allocation_block.allocation = allocation;
+    return allocation_block;
+  }
+
   MemorySpaceAssignment::AllocationSequence* allocations_;
   const MemorySpaceAssignment::Options& options_;
   const HloAliasAnalysis& alias_analysis_;
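The factory replaces the constructor that RepackAllocationBlock previously declared, leaving the struct free of user-declared constructors while still initializing every field (offset starts at -1, i.e. not yet assigned). A usage sketch, as if written inside AlternateMemoryBestFitHeap (values and the null allocation are placeholders):

  RepackAllocationBlock block = MakeRepackAllocationBlock(
      /*start_time=*/0, /*end_time=*/10, /*size=*/64,
      /*initial_offset=*/128, /*id=*/0, /*allocation=*/nullptr);
  // block.offset == -1 until a repacker assigns a new offset;
  // block.initial_offset records where the allocation originally lived.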
@@ -33,14 +33,26 @@ class MemorySpaceAssignmentRepacker {
   // successful and the allocations were modified, the offset field holds the
   // new offset. To support aliased allocations, AllocationBlock also includes a
   // vector of AllocationBlock pointers, called colocations. All AllocationBlock
-  // objects within the colocations must get the same offset.
+  // objects within the colocations must get the same offset. The id should be
+  // unique and is used to ensure determinism for comparison tie-breaker.
   struct AllocationBlock {
     int64 start_time;
     int64 end_time;
     int64 size;
     int64 offset;
     int64 initial_offset;
+    int64 id;
     std::vector<AllocationBlock*> colocations;

+    std::string ToString() const {
+      return absl::StrCat("[", start_time, ", ", end_time, "] : size = ", size,
+                          ", offset = ", offset,
+                          " initial offset = ", initial_offset);
+    }
+
+    // This is required by BufferIntervalCompare as a tie breaker. Use a unique
+    // and deterministic id.
+    bool operator<(const AllocationBlock& other) const { return id < other.id; }
   };

   // Repack the AllocationBlocks provided in the parameter. Returns true if
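Because operator< compares only the id, AllocationBlock now defines a strict weak ordering that is stable across runs, unlike anything derived from pointer values. A small sketch (using aggregate initialization in field order):

  using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock;
  AllocationBlock a{/*start_time=*/0, /*end_time=*/5, /*size=*/8,
                    /*offset=*/-1, /*initial_offset=*/0, /*id=*/1, {}};
  AllocationBlock b = a;
  b.id = 0;
  std::vector<AllocationBlock*> blocks = {&a, &b};
  std::sort(blocks.begin(), blocks.end(),
            [](const AllocationBlock* x, const AllocationBlock* y) {
              return *x < *y;  // delegates to the id comparison
            });
  // blocks == {&b, &a}: ordered by id, not by where the objects live.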
@@ -18,7 +18,7 @@ limitations under the License.
 namespace xla {

 bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
+    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
   // If the buffer is a tuple, don't use this algorithm for now. The buffers
   // that are pointed to by the tuple will still use this algorithm. Because
   // tuples are cheap to place in the alternate memory (they are just pointers)
@@ -26,7 +26,8 @@ class MemorySpaceAssignmentUtils {
   // Returns true if this buffer is allowed to be placed in the alternate
   // memory.
   static bool IsIntervalAllowedInAlternateMemory(
-      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval);
+      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval&
+          interval);
 };

 }  // namespace xla