[XLA] When allocating for later uses, use an use_times dict to improve compile time complexity.
We were iterating the last use time one by one. This CL makes it more efficient by finding longer allocations that are actually used. PiperOrigin-RevId: 299396736 Change-Id: Iac245018bb53b8d8ea346474321bab3bd7d2909d
This commit is contained in:
parent
950b054440
commit
423c2cae26
@ -424,10 +424,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
aliased_allocation->chunk(), definition_time, definition_time));
|
||||
}
|
||||
|
||||
std::vector<int64> use_times(uses.size());
|
||||
for (int i = 0; i < uses.size(); ++i) {
|
||||
use_times[i] = instruction_schedule.at(uses[i].instruction);
|
||||
}
|
||||
// Iterate over the uses.
|
||||
for (HloUse use : uses) {
|
||||
int64 use_time = instruction_schedule.at(use.instruction);
|
||||
int64 last_use_time = instruction_schedule.at(uses.back().instruction);
|
||||
int64 latest_prefetch_time = use_time;
|
||||
|
||||
if (use.instruction->parent() != defining_computation) {
|
||||
@ -457,7 +460,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
AllocationRequest request;
|
||||
request.start_time = definition_time;
|
||||
request.end_time = use_time;
|
||||
request.last_use_time = last_use_time;
|
||||
request.use_times = &use_times;
|
||||
request.latest_prefetch_time = latest_prefetch_time;
|
||||
request.use = use;
|
||||
request.buffer = value;
|
||||
@ -692,7 +695,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
|
||||
VLOG(2) << "Finding allocation for " << request.buffer->ToShortString()
|
||||
<< " (" << request.start_time << ", " << request.end_time
|
||||
<< ") latest prefetch = " << request.latest_prefetch_time
|
||||
<< " last use = " << request.last_use_time
|
||||
<< " last use = " << request.use_times->back()
|
||||
<< " use = " << request.use.ToString() << ". Size = " << request.size
|
||||
<< ", def pos = " << defining_position.ToString();
|
||||
CHECK_LE(request.start_time, request.end_time);
|
||||
@ -880,7 +883,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
|
||||
// the last use time, we try to find an allocation that is available for the
|
||||
// entire Producer to Use2 range.
|
||||
absl::optional<ChunkCandidate> chunk_candidate =
|
||||
FindBestChunkCandidate(request.end_time, request.last_use_time,
|
||||
FindBestChunkCandidate(request.end_time, *request.use_times,
|
||||
preferred_offset, &alternate_mem_interval);
|
||||
// Check if the new heap size fits within limits. Also ensure if a
|
||||
// preferred offset was provided, that offset was used.
|
||||
@ -1045,7 +1048,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
||||
}
|
||||
|
||||
auto chunk_candidate = FindBestChunkCandidate(
|
||||
request.end_time, request.last_use_time,
|
||||
request.end_time, *request.use_times,
|
||||
/*preferred_offset=*/absl::nullopt, &alternate_mem_interval);
|
||||
// Check if we could find a suitable chunk.
|
||||
if (chunk_candidate) {
|
||||
@ -1072,13 +1075,15 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
||||
|
||||
absl::optional<AlternateMemoryBestFitHeap::ChunkCandidate>
|
||||
AlternateMemoryBestFitHeap::FindBestChunkCandidate(
|
||||
int64 end_time, int64 last_use_time, absl::optional<int64> preferred_offset,
|
||||
int64 end_time, const std::vector<int64>& use_times,
|
||||
absl::optional<int64> preferred_offset,
|
||||
BufferInterval* alternate_mem_interval) const {
|
||||
if (!preferred_offset) {
|
||||
// Find a chunk that's as long living as possible.
|
||||
for (alternate_mem_interval->end = last_use_time;
|
||||
alternate_mem_interval->end >= end_time;
|
||||
--alternate_mem_interval->end) {
|
||||
// Find a chunk that's as long living as possible iterating in reverse over
|
||||
// the use times.
|
||||
for (auto use_time = use_times.rbegin();
|
||||
use_time != use_times.rend() && *use_time >= end_time; ++use_time) {
|
||||
alternate_mem_interval->end = *use_time;
|
||||
ChunkCandidate chunk_candidate =
|
||||
FindChunkCandidate(*alternate_mem_interval);
|
||||
if (chunk_candidate.heap_size <= available_heap_size()) {
|
||||
|
@ -639,13 +639,13 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
// Segment Segment Segment
|
||||
//
|
||||
// start_time and end_time are the start and end logical times of the segment.
|
||||
// last_use_time is the time of the last use for this buffer (Use3 in the
|
||||
// figure). latest_prefetch_time is the latest time we can schedule the
|
||||
// CopyDone for a prefetch.
|
||||
// use_times is a sorted sequence of the times of all uses.
|
||||
// latest_prefetch_time is the latest time we can schedule the CopyDone for a
|
||||
// prefetch.
|
||||
struct AllocationRequest {
|
||||
int64 start_time;
|
||||
int64 end_time;
|
||||
int64 last_use_time;
|
||||
const std::vector<int64>* use_times;
|
||||
int64 latest_prefetch_time;
|
||||
int64 size;
|
||||
HloUse use;
|
||||
@ -700,7 +700,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
// availability if no preferred offset is given, or at the preferred_offset if
|
||||
// it is given.
|
||||
absl::optional<ChunkCandidate> FindBestChunkCandidate(
|
||||
int64 end_time, int64 last_use_time,
|
||||
int64 end_time, const std::vector<int64>& use_times,
|
||||
absl::optional<int64> preferred_offset,
|
||||
BufferInterval* alternate_mem_interval) const;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user