[XLA] When checking async copy ordering, also check for pending copies.

PiperOrigin-RevId: 291286705
Change-Id: I6fb6303c6ad92c0f7de8a2aa198196f376d96f00
This commit is contained in:
Berkin Ilbeyi 2020-01-23 18:10:20 -08:00 committed by TensorFlower Gardener
parent 10c964380c
commit 4ae9c44fa5
2 changed files with 22 additions and 2 deletions

View File

@ -859,8 +859,8 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
VLOG(4) << "This would violate the outstanding async copy limit.";
continue;
}
if (async_copy_ordering_.ViolatesOrdering(alternate_mem_interval.start,
alternate_mem_interval.end)) {
if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
alternate_mem_interval.end)) {
VLOG(4) << "This would violate asynchronous copy ordering.";
continue;
}
@ -937,6 +937,23 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
return num_async_copies + 1 > options_.max_outstanding_async_copies;
}
bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
int64 start_time, int64 end_time) const {
if (async_copy_ordering_.ViolatesOrdering(start_time, end_time)) {
return true;
}
// Also check pending async copies.
for (const auto& async_copy : pending_async_copies_) {
if (async_copy.destination == MemorySpace::kAlternate &&
async_copy.start_time <= end_time &&
start_time <= async_copy.end_time) {
return true;
}
}
return false;
}
bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
int64 start_time, int64 end_time, int64 last_use_time,
HloPosition defining_position, HloUse use,

View File

@ -688,6 +688,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time,
int64 end_time) const;
// Return true if the asynchronous copy would violate the pipelining order.
bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
// Adds an asynchronous copy to the allocations.
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
MemorySpace memory_space, Chunk chunk, int64 start_time,