[XLA] Allow copy-done to be scheduled earlier to avoid copy ordering issues.
We previously required that copy-dones to be scheduled right before the use. We also require asynchronous copies to maintain pipelining order (no nested copy-start/copy-done pairs). This could mean that a smaller buffer that has shorter copy-start/copy-done duration may block a larger buffer due to copy ordering. E.g. this situation might arise which is not allowed: small tensor in default mem---------------------->CS----->CD->use large tensor in default mem------------------>CS----------------->CD->use ====================================================================> time This CL checks if there is an already committed asynchronous copy that violates the pipelining behavior. If so, we attempt to move the copy-done earlier: small tensor in default mem---------------------->CS----->CD->use large tensor in default mem-------->CS----------------->CD----------->use ====================================================================> time PiperOrigin-RevId: 322621813 Change-Id: I8287c11c96a6d71de86a5f2e22cf1846c26ef4f3
This commit is contained in:
parent
30885b432e
commit
70edbdb6c7
@ -235,6 +235,11 @@ int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
|
|||||||
return std::min(start_time + min_overlap_count_, latest_end_time);
|
return std::min(start_time + min_overlap_count_, latest_end_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
|
||||||
|
const HloUse& use, int64 start_time, int64 end_time) const {
|
||||||
|
return end_time_ - min_overlap_count_;
|
||||||
|
}
|
||||||
|
|
||||||
void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
|
void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
|
||||||
int64 start_time,
|
int64 start_time,
|
||||||
int64 end_time) {
|
int64 end_time) {
|
||||||
@ -355,6 +360,49 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
|
|||||||
return end_time;
|
return end_time;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
|
||||||
|
const HloUse& use, int64 start_time, int64 end_time) const {
|
||||||
|
const Shape& shape = ShapeUtil::GetSubshape(
|
||||||
|
use.instruction->operand(use.operand_number)->shape(), use.operand_index);
|
||||||
|
// Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
|
||||||
|
float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
|
||||||
|
// Estimate the time we would save by having this op in alternate memory.
|
||||||
|
float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
|
||||||
|
float elapsed_time_in_alternate_mem =
|
||||||
|
cost_analysis_.GetInstructionElapsedInAlternateMemory(
|
||||||
|
*use.instruction, use.operand_number,
|
||||||
|
/*output_in_alternate_mem=*/false);
|
||||||
|
float inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
|
||||||
|
int end_nest_level = while_nest_level_[end_time];
|
||||||
|
|
||||||
|
// Find the latest time we're allowed to start prefetching.
|
||||||
|
float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed;
|
||||||
|
int latest_prefetch_time;
|
||||||
|
for (latest_prefetch_time = end_time - 1;
|
||||||
|
latest_prefetch_time >= start_time &&
|
||||||
|
(while_nest_level_[latest_prefetch_time] != end_nest_level ||
|
||||||
|
min_interval >
|
||||||
|
GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
|
||||||
|
inst_elapsed_reduction);
|
||||||
|
--latest_prefetch_time) {
|
||||||
|
}
|
||||||
|
|
||||||
|
return latest_prefetch_time;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
|
||||||
|
int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const {
|
||||||
|
// Iterate towards the beginning until we find a suitable end time that is the
|
||||||
|
// same while nest level as the original prefetch end time.
|
||||||
|
int64 original_nest_level = while_nest_level_[original_prefetch_end_time];
|
||||||
|
int64 new_prefetch_end_time;
|
||||||
|
for (new_prefetch_end_time = proposed_prefetch_end_time;
|
||||||
|
while_nest_level_[new_prefetch_end_time] != original_nest_level;
|
||||||
|
--new_prefetch_end_time) {
|
||||||
|
}
|
||||||
|
return new_prefetch_end_time;
|
||||||
|
}
|
||||||
|
|
||||||
void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
|
void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
|
||||||
int64 start_time,
|
int64 start_time,
|
||||||
int64 end_time) {
|
int64 end_time) {
|
||||||
@ -374,14 +422,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
|
|||||||
|
|
||||||
// Find the latest time we're allowed to start prefetching.
|
// Find the latest time we're allowed to start prefetching.
|
||||||
float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
|
float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
|
||||||
for (latest_prefetch_time_ = end_logical_time_ - 1;
|
latest_prefetch_time_ = LatestPrefetchStartTime(use, start_time, end_time);
|
||||||
latest_prefetch_time_ >= start_time &&
|
|
||||||
(while_nest_level_[latest_prefetch_time_] != end_nest_level ||
|
|
||||||
min_interval > GetLogicalIntervalElapsed(latest_prefetch_time_,
|
|
||||||
end_logical_time_) +
|
|
||||||
inst_elapsed_reduction_);
|
|
||||||
--latest_prefetch_time_) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the earliest time we're allowed to start prefetching.
|
// Find the earliest time we're allowed to start prefetching.
|
||||||
float max_interval = max_async_copy_to_overlap_ratio_ *
|
float max_interval = max_async_copy_to_overlap_ratio_ *
|
||||||
@ -1229,15 +1270,21 @@ void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
|
|||||||
ranges_.erase(copy_it);
|
ranges_.erase(copy_it);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time,
|
absl::optional<AsynchronousCopy> AsynchronousCopyOrdering::ViolatesOrdering(
|
||||||
int64 end_time) const {
|
int64 start_time, int64 end_time) const {
|
||||||
// We allow identical start and end times. It is enough to check for just the
|
// We allow identical start and end times. It is enough to check for just the
|
||||||
// start time in case we find a match in ranges_ because the found value will
|
// start time in case we find a match in ranges_ because the found value will
|
||||||
// either be identical to {start_time, end_time} (and this doesn't violate) or
|
// either be identical to {start_time, end_time} (and this doesn't violate) or
|
||||||
// its start_time will be smaller and end_time will be larger (this violates).
|
// its start_time will be smaller and end_time will be larger (this violates).
|
||||||
auto copy_it = ranges_.find(
|
auto copy_it = ranges_.find(
|
||||||
{start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
|
{start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
|
||||||
return copy_it != ranges_.end() && copy_it->start_time != start_time;
|
if (copy_it != ranges_.end() && copy_it->start_time != start_time) {
|
||||||
|
VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time
|
||||||
|
<< ") and (" << copy_it->start_time << ", " << copy_it->end_time
|
||||||
|
<< ")";
|
||||||
|
return *copy_it;
|
||||||
|
}
|
||||||
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ MemorySpaceAssignment::Allocation*
|
/*static*/ MemorySpaceAssignment::Allocation*
|
||||||
@ -1734,8 +1781,9 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
|
absl::optional<AsynchronousCopy>
|
||||||
int64 start_time, int64 end_time) const {
|
AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64 start_time,
|
||||||
|
int64 end_time) const {
|
||||||
return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
|
return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1945,6 +1993,50 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime(
|
||||||
|
const AllocationRequest& request, int64 earliest_prefetch_time) const {
|
||||||
|
int64 prefetch_end_time = request.latest_prefetch_time;
|
||||||
|
|
||||||
|
for (int retry_number = 0;
|
||||||
|
retry_number < options_.prefetch_copy_done_reorder_max_retries;
|
||||||
|
++retry_number) {
|
||||||
|
int64 latest_prefetch_time =
|
||||||
|
options_.prefetch_interval_picker->LatestPrefetchStartTime(
|
||||||
|
request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
|
||||||
|
// Return if we couldn't find a suitable prefetch start time.
|
||||||
|
if (latest_prefetch_time < earliest_prefetch_time) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return either if there is no other violating asynchronous copy (since we
|
||||||
|
// don't need to change the prefetch end time) or if the violating
|
||||||
|
// asynchronous copy ends after the prefetch end time.
|
||||||
|
auto violating_async_copy =
|
||||||
|
ViolatesAsyncCopyOrdering(latest_prefetch_time, prefetch_end_time);
|
||||||
|
if (!violating_async_copy ||
|
||||||
|
violating_async_copy->end_time >= prefetch_end_time) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
VLOG(4) << "Violating async copy: (" << violating_async_copy->start_time
|
||||||
|
<< ", " << violating_async_copy->end_time << ")";
|
||||||
|
|
||||||
|
int64 new_prefetch_end_time =
|
||||||
|
options_.prefetch_interval_picker->LatestPrefetchEndTime(
|
||||||
|
prefetch_end_time, violating_async_copy->end_time);
|
||||||
|
if (new_prefetch_end_time > earliest_prefetch_time) {
|
||||||
|
VLOG(3) << "Update prefetch end time = " << new_prefetch_end_time;
|
||||||
|
prefetch_end_time = new_prefetch_end_time;
|
||||||
|
} else {
|
||||||
|
VLOG(3) << "Can't update prefetch end time = " << new_prefetch_end_time
|
||||||
|
<< " because earliest prefetch start time = "
|
||||||
|
<< earliest_prefetch_time;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefetch_end_time;
|
||||||
|
}
|
||||||
|
|
||||||
bool AlternateMemoryBestFitHeap::Prefetch(
|
bool AlternateMemoryBestFitHeap::Prefetch(
|
||||||
const AllocationRequest& request,
|
const AllocationRequest& request,
|
||||||
const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
|
const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
|
||||||
@ -1966,9 +2058,11 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
|||||||
earliest_prefetch_time =
|
earliest_prefetch_time =
|
||||||
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
|
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
|
||||||
}
|
}
|
||||||
options_.prefetch_interval_picker->Begin(request.use->hlo_use,
|
int64 prefetch_end_time =
|
||||||
earliest_prefetch_time,
|
FindPrefetchEndTime(request, earliest_prefetch_time);
|
||||||
request.latest_prefetch_time);
|
|
||||||
|
options_.prefetch_interval_picker->Begin(
|
||||||
|
request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
|
||||||
VLOG(3) << "Trying prefetch picker = "
|
VLOG(3) << "Trying prefetch picker = "
|
||||||
<< options_.prefetch_interval_picker->ToDebugString();
|
<< options_.prefetch_interval_picker->ToDebugString();
|
||||||
|
|
||||||
@ -1988,19 +2082,19 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
|||||||
: 0;
|
: 0;
|
||||||
while (!options_.prefetch_interval_picker->Done()) {
|
while (!options_.prefetch_interval_picker->Done()) {
|
||||||
alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
|
alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
|
||||||
CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
|
CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
|
||||||
VLOG(4) << "Trying alternate memory allocation ("
|
VLOG(4) << "Trying alternate memory allocation ("
|
||||||
<< alternate_mem_interval.start << ", " << request.end_time << ")";
|
<< alternate_mem_interval.start << ", " << request.end_time << ")";
|
||||||
// If this additional asynchronous copy would violate the limit, try a
|
// If this additional asynchronous copy would violate the limit, try a
|
||||||
// different interval.
|
// different interval.
|
||||||
if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
|
if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
|
||||||
request.latest_prefetch_time)) {
|
prefetch_end_time)) {
|
||||||
VLOG(4) << "This would violate asynchronous copy ordering.";
|
VLOG(4) << "This would violate asynchronous copy ordering.";
|
||||||
prefetch_failed_due_to_async_copy_ = true;
|
prefetch_failed_due_to_async_copy_ = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (ViolatesMaximumOutstandingAsyncCopies(
|
if (ViolatesMaximumOutstandingAsyncCopies(
|
||||||
alternate_mem_interval.start, request.latest_prefetch_time,
|
alternate_mem_interval.start, prefetch_end_time,
|
||||||
/*is_prefetch=*/true, extra_async_copy_limit)) {
|
/*is_prefetch=*/true, extra_async_copy_limit)) {
|
||||||
VLOG(4) << "This would violate the outstanding async copy limit.";
|
VLOG(4) << "This would violate the outstanding async copy limit.";
|
||||||
prefetch_failed_due_to_async_copy_ = true;
|
prefetch_failed_due_to_async_copy_ = true;
|
||||||
@ -2022,7 +2116,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
|||||||
|
|
||||||
AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
|
AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
|
||||||
chunk_candidate->chunk, alternate_mem_interval.start,
|
chunk_candidate->chunk, alternate_mem_interval.start,
|
||||||
request.end_time, request.latest_prefetch_time,
|
request.end_time, prefetch_end_time,
|
||||||
request.allocation_value->allocation_sequence());
|
request.allocation_value->allocation_sequence());
|
||||||
|
|
||||||
request.allocation_value->allocation_sequence()->back()->AddUse(
|
request.allocation_value->allocation_sequence()->back()->AddUse(
|
||||||
|
@ -198,6 +198,17 @@ class PrefetchIntervalPicker {
|
|||||||
virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
||||||
int64 latest_end_time) const = 0;
|
int64 latest_end_time) const = 0;
|
||||||
|
|
||||||
|
// Returns the latest time that a prefetch can start.
|
||||||
|
virtual int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
|
||||||
|
int64 end_time) const = 0;
|
||||||
|
|
||||||
|
// Returns the latest time that a prefetch can end that is less than or equal
|
||||||
|
// to proposed_prefetch_end_time.
|
||||||
|
virtual int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
|
||||||
|
int64 proposed_prefetch_end_time) const {
|
||||||
|
return proposed_prefetch_end_time;
|
||||||
|
}
|
||||||
|
|
||||||
// Begins the iterator for the first start time of the prefetch.
|
// Begins the iterator for the first start time of the prefetch.
|
||||||
virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;
|
virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;
|
||||||
|
|
||||||
@ -256,6 +267,9 @@ class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
|
|||||||
int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
||||||
int64 latest_end_time) const override;
|
int64 latest_end_time) const override;
|
||||||
|
|
||||||
|
int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
|
||||||
|
int64 end_time) const override;
|
||||||
|
|
||||||
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
|
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
|
||||||
|
|
||||||
int64 Next() override;
|
int64 Next() override;
|
||||||
@ -292,6 +306,11 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
|
|||||||
int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
||||||
int64 latest_end_time) const override;
|
int64 latest_end_time) const override;
|
||||||
|
|
||||||
|
int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
|
||||||
|
int64 end_time) const override;
|
||||||
|
int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
|
||||||
|
int64 proposed_prefetch_end_time) const override;
|
||||||
|
|
||||||
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
|
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
|
||||||
|
|
||||||
int64 Next() override;
|
int64 Next() override;
|
||||||
@ -395,6 +414,11 @@ class MemorySpaceAssignment {
|
|||||||
// max_outstanding_prefetches).
|
// max_outstanding_prefetches).
|
||||||
int64 while_use_extra_outstanding_prefetch_limit = 0;
|
int64 while_use_extra_outstanding_prefetch_limit = 0;
|
||||||
|
|
||||||
|
// Specifies the maximum number of times we are willing to move a copy
|
||||||
|
// done of a prefetch earlier due to an asynchronous copy ordering
|
||||||
|
// violation.
|
||||||
|
int64 prefetch_copy_done_reorder_max_retries = 1;
|
||||||
|
|
||||||
// Specifies the maximum number of retries that will be performed for each
|
// Specifies the maximum number of retries that will be performed for each
|
||||||
// value in case prefetching failed due to running out of asynchronous
|
// value in case prefetching failed due to running out of asynchronous
|
||||||
// copies or asynchronous copy ordering.
|
// copies or asynchronous copy ordering.
|
||||||
@ -850,9 +874,9 @@ class AsynchronousCopyOrdering {
|
|||||||
// Removes an asynchronous copy. CHECKs that it is removed.
|
// Removes an asynchronous copy. CHECKs that it is removed.
|
||||||
void RemoveCopy(const AsynchronousCopy& copy);
|
void RemoveCopy(const AsynchronousCopy& copy);
|
||||||
|
|
||||||
// Returns true if the addition of an asynchronous copy in the the given time
|
// If the addition of an asynchronous copy in the given time interval would
|
||||||
// interval would violate the asynchronous copy ordering. E.g., consider the
|
// violate the asynchronous copy ordering, returns the violating
|
||||||
// following scenario:
|
// already-committed asynchronous copy. E.g., consider the following scenario:
|
||||||
// CS CD
|
// CS CD
|
||||||
// already committed async copy: +-----------+
|
// already committed async copy: +-----------+
|
||||||
// new async copy: +--------+
|
// new async copy: +--------+
|
||||||
@ -860,7 +884,8 @@ class AsynchronousCopyOrdering {
|
|||||||
// The new asynchronous copy would violate the ordering guarantee because the
|
// The new asynchronous copy would violate the ordering guarantee because the
|
||||||
// copy start is after an already committed asynchronous copy while its copy
|
// copy start is after an already committed asynchronous copy while its copy
|
||||||
// done is before the committed copy.
|
// done is before the committed copy.
|
||||||
bool ViolatesOrdering(int64 start_time, int64 end_time) const;
|
absl::optional<AsynchronousCopy> ViolatesOrdering(int64 start_time,
|
||||||
|
int64 end_time) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Stores asynchronous copies in a tree set respecting the pipelining order.
|
// Stores asynchronous copies in a tree set respecting the pipelining order.
|
||||||
@ -981,6 +1006,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
// Try evicting to default memory space. Returns true if successful.
|
// Try evicting to default memory space. Returns true if successful.
|
||||||
bool Evict(const AllocationRequest& request);
|
bool Evict(const AllocationRequest& request);
|
||||||
|
|
||||||
|
// Returns the time a copy done of a prefetch should be scheduled.
|
||||||
|
int64 FindPrefetchEndTime(const AllocationRequest& request,
|
||||||
|
int64 earliest_prefetch_time) const;
|
||||||
|
|
||||||
// Try prefetching to alternate memory space. Returns true if successful.
|
// Try prefetching to alternate memory space. Returns true if successful.
|
||||||
bool Prefetch(
|
bool Prefetch(
|
||||||
const AllocationRequest& request,
|
const AllocationRequest& request,
|
||||||
@ -1045,8 +1074,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
|||||||
int64 start_time, int64 end_time, bool is_prefetch,
|
int64 start_time, int64 end_time, bool is_prefetch,
|
||||||
int64 extra_async_copy_limit = 0) const;
|
int64 extra_async_copy_limit = 0) const;
|
||||||
|
|
||||||
// Return true if the asynchronous copy would violate the pipelining order.
|
// If the asynchronous copy would violate the pipelining order, returns the
|
||||||
bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
|
// violating asynchronous copy.
|
||||||
|
absl::optional<AsynchronousCopy> ViolatesAsyncCopyOrdering(
|
||||||
|
int64 start_time, int64 end_time) const;
|
||||||
|
|
||||||
// Adds an asynchronous copy to the allocations.
|
// Adds an asynchronous copy to the allocations.
|
||||||
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
|
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
|
||||||
|
@ -286,6 +286,92 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
|||||||
MemorySpaceAssignmentCostAnalysis::Cache cache_;
|
MemorySpaceAssignmentCostAnalysis::Cache cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// For testing purposes, we define a cost analysis where we can control the
|
||||||
|
// elapsed times of each HLO and asynchronous copy.
|
||||||
|
class FakeMemorySpaceAssignmentCostAnalysis
|
||||||
|
: public MemorySpaceAssignmentCostAnalysis {
|
||||||
|
public:
|
||||||
|
static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
|
||||||
|
Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
|
||||||
|
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
|
||||||
|
HloLiveRange::Run(module.schedule(), *alias_analysis,
|
||||||
|
module.entry_computation()));
|
||||||
|
auto call_graph = CallGraph::Build(&module);
|
||||||
|
return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
|
||||||
|
cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
|
||||||
|
/*alternate_mem_bandwidth_bytes_per_second=*/1,
|
||||||
|
std::move(alias_analysis), std::move(hlo_live_range),
|
||||||
|
std::move(call_graph)));
|
||||||
|
}
|
||||||
|
|
||||||
|
float GetInstructionElapsed(
|
||||||
|
const HloInstruction& instruction) const override {
|
||||||
|
if (get_instruction_elapsed_override_) {
|
||||||
|
return get_instruction_elapsed_override_(instruction);
|
||||||
|
}
|
||||||
|
return 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
float GetInstructionElapsedInAlternateMemory(
|
||||||
|
const HloInstruction& instruction,
|
||||||
|
absl::optional<int64> operand_in_alternate_mem,
|
||||||
|
bool output_in_alternate_mem) const override {
|
||||||
|
if (get_instruction_elapsed_in_alternate_memory_override_) {
|
||||||
|
return get_instruction_elapsed_in_alternate_memory_override_(
|
||||||
|
instruction, operand_in_alternate_mem, output_in_alternate_mem);
|
||||||
|
}
|
||||||
|
if (operand_in_alternate_mem) {
|
||||||
|
return 0.5;
|
||||||
|
} else {
|
||||||
|
return 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float GetAsyncCopyElapsed(const Shape& shape) const override {
|
||||||
|
if (get_async_copy_elapsed_override_) {
|
||||||
|
return get_async_copy_elapsed_override_(shape);
|
||||||
|
}
|
||||||
|
return 3.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following methods can be used to override what the above API calls
|
||||||
|
// return.
|
||||||
|
void SetOverrideForGetInstructionElapsed(
|
||||||
|
std::function<float(const HloInstruction&)> function) {
|
||||||
|
get_instruction_elapsed_override_ = function;
|
||||||
|
}
|
||||||
|
void SetOverrideForGetInstructionElapsedInAlternateMemory(
|
||||||
|
std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
|
||||||
|
function) {
|
||||||
|
get_instruction_elapsed_in_alternate_memory_override_ = function;
|
||||||
|
}
|
||||||
|
void SetOverrideForGetAsyncCopyElapsed(
|
||||||
|
std::function<float(const Shape&)> function) {
|
||||||
|
get_async_copy_elapsed_override_ = function;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
FakeMemorySpaceAssignmentCostAnalysis(
|
||||||
|
const HloCostAnalysis& cost_analysis,
|
||||||
|
float async_copy_bandwidth_bytes_per_second,
|
||||||
|
float alternate_mem_bandwidth_bytes_per_second,
|
||||||
|
std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||||
|
std::unique_ptr<HloLiveRange> hlo_live_range,
|
||||||
|
std::unique_ptr<CallGraph> call_graph)
|
||||||
|
: MemorySpaceAssignmentCostAnalysis(
|
||||||
|
cost_analysis, async_copy_bandwidth_bytes_per_second,
|
||||||
|
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
|
||||||
|
std::move(hlo_live_range), std::move(call_graph)) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::function<float(const HloInstruction&)>
|
||||||
|
get_instruction_elapsed_override_ = nullptr;
|
||||||
|
std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
|
||||||
|
get_instruction_elapsed_in_alternate_memory_override_ = nullptr;
|
||||||
|
std::function<float(const Shape&)> get_async_copy_elapsed_override_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
|
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
|
||||||
// A module consisting of a single parameter. Inputs/outputs are currently
|
// A module consisting of a single parameter. Inputs/outputs are currently
|
||||||
// excluded from memory space assignment.
|
// excluded from memory space assignment.
|
||||||
@ -3750,6 +3836,123 @@ TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
|
|||||||
buffer_interval_compare, &prefetch_interval_picker);
|
buffer_interval_compare, &prefetch_interval_picker);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) {
|
||||||
|
// This tests the case where an earlier placed smaller buffer may block a
|
||||||
|
// larger buffer due to asynchronous copy ordering. The smaller buffer (the
|
||||||
|
// operand of sin) will be placed first. The cos, whose operand is 3 times
|
||||||
|
// larger than sin's, needs longer time for the asynhronous copy. The cos is
|
||||||
|
// placed right after sin, leading to a copy ordering violation:
|
||||||
|
//
|
||||||
|
// param1------------------>CS----->CD->sin
|
||||||
|
// param0------------->CS------------------->CD->cos
|
||||||
|
//
|
||||||
|
// To fix this, we need to move copy done for cos earlier and ensure both of
|
||||||
|
// these buffers get alternate memory allocations:
|
||||||
|
//
|
||||||
|
// param1------------------>CS----->CD->sin
|
||||||
|
// param0-->CS------------------->CD------------>cos
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule module, is_scheduled=true
|
||||||
|
|
||||||
|
ENTRY Entry {
|
||||||
|
param0 = f32[8,3] parameter(0)
|
||||||
|
param1 = f32[2,4] parameter(1)
|
||||||
|
a = f32[2,4] negate(param1)
|
||||||
|
b = f32[2,4] negate(a)
|
||||||
|
c = f32[2,4] negate(b)
|
||||||
|
d = f32[2,4] negate(c)
|
||||||
|
e = f32[2,4] negate(d)
|
||||||
|
f = f32[2,4] negate(e)
|
||||||
|
g = f32[2,4] negate(f)
|
||||||
|
h = f32[2,4] negate(g)
|
||||||
|
i = f32[2,4] negate(h)
|
||||||
|
j = f32[2,4] negate(i)
|
||||||
|
k = f32[2,4] negate(j)
|
||||||
|
l = f32[2,4] negate(k)
|
||||||
|
m = f32[2,4] negate(l)
|
||||||
|
n = f32[2,4] negate(m)
|
||||||
|
sin = f32[2,4] sine(param1)
|
||||||
|
o = f32[2,4] negate(n)
|
||||||
|
cos = f32[8,3] cosine(param0)
|
||||||
|
ROOT tuple = (f32[8,3], f32[2,4], f32[2,4]) tuple(cos, sin, o)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
|
||||||
|
[](const MemorySpaceAssignment::BufferInterval& a,
|
||||||
|
const MemorySpaceAssignment::BufferInterval& b) {
|
||||||
|
auto get_opcode_priority = [](const HloOpcode& opcode) {
|
||||||
|
switch (opcode) {
|
||||||
|
case HloOpcode::kSin:
|
||||||
|
return 0;
|
||||||
|
case HloOpcode::kCos:
|
||||||
|
return 1;
|
||||||
|
case HloOpcode::kTanh:
|
||||||
|
return 2;
|
||||||
|
default:
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto get_user_priority = [&](const HloValue& value) {
|
||||||
|
int priority = INT_MAX;
|
||||||
|
for (const auto& use : value.uses()) {
|
||||||
|
priority = std::min(priority,
|
||||||
|
get_opcode_priority(use.instruction->opcode()));
|
||||||
|
}
|
||||||
|
return priority;
|
||||||
|
};
|
||||||
|
|
||||||
|
return get_user_priority(*a.buffer) < get_user_priority(*b.buffer);
|
||||||
|
};
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
|
HloCostAnalysis hlo_cost_analysis(ShapeSize);
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
|
||||||
|
FakeMemorySpaceAssignmentCostAnalysis::Create(
|
||||||
|
hlo_cost_analysis, *module));
|
||||||
|
cost_analysis->SetOverrideForGetAsyncCopyElapsed([](const Shape& shape) {
|
||||||
|
// This should return 2 for f32[2,4] and 6 for f32[8,3].
|
||||||
|
return ShapeSize(shape) / 16;
|
||||||
|
});
|
||||||
|
CostAnalysisPrefetchIntervalPicker interval_picker(
|
||||||
|
*cost_analysis,
|
||||||
|
/*min_async_copy_to_overlap_ratio=*/1.0,
|
||||||
|
/*max_async_copy_to_overlap_ratio=*/4.0,
|
||||||
|
/*preferred_async_copy_to_overlap_ratio=*/1.5);
|
||||||
|
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
|
||||||
|
buffer_interval_compare, &interval_picker);
|
||||||
|
|
||||||
|
// Check that both cos and sin could get their operands prefetched.
|
||||||
|
const HloInstruction* cos =
|
||||||
|
module->entry_computation()->GetInstructionWithName("cos");
|
||||||
|
const HloInstruction* sin =
|
||||||
|
module->entry_computation()->GetInstructionWithName("sin");
|
||||||
|
EXPECT_THAT(sin->operand(0),
|
||||||
|
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
|
||||||
|
op::Parameter(1)));
|
||||||
|
EXPECT_THAT(cos->operand(0),
|
||||||
|
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
|
||||||
|
op::Parameter(0)));
|
||||||
|
|
||||||
|
// Sanity check that the cos' operand copy-done is scheduled earlier than
|
||||||
|
// sin's operand.
|
||||||
|
auto find_schedule_index = [&](const HloInstruction* instruction) {
|
||||||
|
const auto& instructions =
|
||||||
|
module->schedule().sequence(module->entry_computation()).instructions();
|
||||||
|
for (int i = 0; i < instructions.size(); ++i) {
|
||||||
|
if (instruction == instructions[i]) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CHECK(false);
|
||||||
|
return -1;
|
||||||
|
};
|
||||||
|
EXPECT_GT(find_schedule_index(sin->operand(0)),
|
||||||
|
find_schedule_index(cos->operand(0)));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(MemorySpaceAssignmentTest, Determinism) {
|
TEST_P(MemorySpaceAssignmentTest, Determinism) {
|
||||||
// Run memory space assignment a few times to make sure every time it compiles
|
// Run memory space assignment a few times to make sure every time it compiles
|
||||||
// to the same thing.
|
// to the same thing.
|
||||||
@ -4046,57 +4249,6 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
|
|||||||
EXPECT_EQ(cross_program_prefetches.size(), 0);
|
EXPECT_EQ(cross_program_prefetches.size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// For testing purposes, we define a cost analysis where we can control the
|
|
||||||
// elapsed times of each HLO and asynchronous copy.
|
|
||||||
class FakeMemorySpaceAssignmentCostAnalysis
|
|
||||||
: public MemorySpaceAssignmentCostAnalysis {
|
|
||||||
public:
|
|
||||||
static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
|
|
||||||
Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
|
|
||||||
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
|
|
||||||
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
|
|
||||||
HloLiveRange::Run(module.schedule(), *alias_analysis,
|
|
||||||
module.entry_computation()));
|
|
||||||
auto call_graph = CallGraph::Build(&module);
|
|
||||||
return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
|
|
||||||
cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
|
|
||||||
/*alternate_mem_bandwidth_bytes_per_second=*/1,
|
|
||||||
std::move(alias_analysis), std::move(hlo_live_range),
|
|
||||||
std::move(call_graph)));
|
|
||||||
}
|
|
||||||
|
|
||||||
float GetInstructionElapsed(
|
|
||||||
const HloInstruction& instruction) const override {
|
|
||||||
return 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
float GetInstructionElapsedInAlternateMemory(
|
|
||||||
const HloInstruction& instruction,
|
|
||||||
absl::optional<int64> operand_in_alternate_mem,
|
|
||||||
bool output_in_alternate_mem) const override {
|
|
||||||
if (operand_in_alternate_mem) {
|
|
||||||
return 0.5;
|
|
||||||
} else {
|
|
||||||
return 1.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
float GetAsyncCopyElapsed(const Shape& shape) const override { return 3.0; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
FakeMemorySpaceAssignmentCostAnalysis(
|
|
||||||
const HloCostAnalysis& cost_analysis,
|
|
||||||
float async_copy_bandwidth_bytes_per_second,
|
|
||||||
float alternate_mem_bandwidth_bytes_per_second,
|
|
||||||
std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
|
||||||
std::unique_ptr<HloLiveRange> hlo_live_range,
|
|
||||||
std::unique_ptr<CallGraph> call_graph)
|
|
||||||
: MemorySpaceAssignmentCostAnalysis(
|
|
||||||
cost_analysis, async_copy_bandwidth_bytes_per_second,
|
|
||||||
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
|
|
||||||
std::move(hlo_live_range), std::move(call_graph)) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
|
using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
|
||||||
|
|
||||||
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
|
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user