diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 8d9510ffae9..21b222609c6 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -859,8 +859,8 @@ bool AlternateMemoryBestFitHeap::FindAllocation( VLOG(4) << "This would violate the outstanding async copy limit."; continue; } - if (async_copy_ordering_.ViolatesOrdering(alternate_mem_interval.start, - alternate_mem_interval.end)) { + if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start, + alternate_mem_interval.end)) { VLOG(4) << "This would violate asynchronous copy ordering."; continue; } @@ -937,6 +937,23 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( return num_async_copies + 1 > options_.max_outstanding_async_copies; } +bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering( + int64 start_time, int64 end_time) const { + if (async_copy_ordering_.ViolatesOrdering(start_time, end_time)) { + return true; + } + + // Also check pending async copies. + for (const auto& async_copy : pending_async_copies_) { + if (async_copy.destination == MemorySpace::kAlternate && + async_copy.start_time <= end_time && + start_time <= async_copy.end_time) { + return true; + } + } + return false; +} + bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy( int64 start_time, int64 end_time, int64 last_use_time, HloPosition defining_position, HloUse use, diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 9bf04a0fbb5..53fe2e4f197 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -688,6 +688,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time, int64 end_time) const; + // Return true if the asynchronous copy would violate the pipelining order. + bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const; + // Adds an asynchronous copy to the allocations. void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation, MemorySpace memory_space, Chunk chunk, int64 start_time,