[XLA] Allow copy-done to be scheduled earlier to avoid copy ordering issues.
We previously required that copy-dones to be scheduled right before the use. We also require asynchronous copies to maintain pipelining order (no nested copy-start/copy-done pairs). This could mean that a smaller buffer that has shorter copy-start/copy-done duration may block a larger buffer due to copy ordering. E.g. this situation might arise which is not allowed: small tensor in default mem---------------------->CS----->CD->use large tensor in default mem------------------>CS----------------->CD->use ====================================================================> time This CL checks if there is an already committed asynchronous copy that violates the pipelining behavior. If so, we attempt to move the copy-done earlier: small tensor in default mem---------------------->CS----->CD->use large tensor in default mem-------->CS----------------->CD----------->use ====================================================================> time PiperOrigin-RevId: 322621813 Change-Id: I8287c11c96a6d71de86a5f2e22cf1846c26ef4f3
This commit is contained in:
parent
30885b432e
commit
70edbdb6c7
tensorflow/compiler/xla/service
@ -235,6 +235,11 @@ int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
|
||||
return std::min(start_time + min_overlap_count_, latest_end_time);
|
||||
}
|
||||
|
||||
int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
|
||||
const HloUse& use, int64 start_time, int64 end_time) const {
|
||||
return end_time_ - min_overlap_count_;
|
||||
}
|
||||
|
||||
void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
|
||||
int64 start_time,
|
||||
int64 end_time) {
|
||||
@ -355,6 +360,49 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
|
||||
return end_time;
|
||||
}
|
||||
|
||||
int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
|
||||
const HloUse& use, int64 start_time, int64 end_time) const {
|
||||
const Shape& shape = ShapeUtil::GetSubshape(
|
||||
use.instruction->operand(use.operand_number)->shape(), use.operand_index);
|
||||
// Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
|
||||
float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
|
||||
// Estimate the time we would save by having this op in alternate memory.
|
||||
float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
|
||||
float elapsed_time_in_alternate_mem =
|
||||
cost_analysis_.GetInstructionElapsedInAlternateMemory(
|
||||
*use.instruction, use.operand_number,
|
||||
/*output_in_alternate_mem=*/false);
|
||||
float inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
|
||||
int end_nest_level = while_nest_level_[end_time];
|
||||
|
||||
// Find the latest time we're allowed to start prefetching.
|
||||
float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed;
|
||||
int latest_prefetch_time;
|
||||
for (latest_prefetch_time = end_time - 1;
|
||||
latest_prefetch_time >= start_time &&
|
||||
(while_nest_level_[latest_prefetch_time] != end_nest_level ||
|
||||
min_interval >
|
||||
GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
|
||||
inst_elapsed_reduction);
|
||||
--latest_prefetch_time) {
|
||||
}
|
||||
|
||||
return latest_prefetch_time;
|
||||
}
|
||||
|
||||
int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
|
||||
int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const {
|
||||
// Iterate towards the beginning until we find a suitable end time that is the
|
||||
// same while nest level as the original prefetch end time.
|
||||
int64 original_nest_level = while_nest_level_[original_prefetch_end_time];
|
||||
int64 new_prefetch_end_time;
|
||||
for (new_prefetch_end_time = proposed_prefetch_end_time;
|
||||
while_nest_level_[new_prefetch_end_time] != original_nest_level;
|
||||
--new_prefetch_end_time) {
|
||||
}
|
||||
return new_prefetch_end_time;
|
||||
}
|
||||
|
||||
void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
|
||||
int64 start_time,
|
||||
int64 end_time) {
|
||||
@ -374,14 +422,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
|
||||
|
||||
// Find the latest time we're allowed to start prefetching.
|
||||
float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
|
||||
for (latest_prefetch_time_ = end_logical_time_ - 1;
|
||||
latest_prefetch_time_ >= start_time &&
|
||||
(while_nest_level_[latest_prefetch_time_] != end_nest_level ||
|
||||
min_interval > GetLogicalIntervalElapsed(latest_prefetch_time_,
|
||||
end_logical_time_) +
|
||||
inst_elapsed_reduction_);
|
||||
--latest_prefetch_time_) {
|
||||
}
|
||||
latest_prefetch_time_ = LatestPrefetchStartTime(use, start_time, end_time);
|
||||
|
||||
// Find the earliest time we're allowed to start prefetching.
|
||||
float max_interval = max_async_copy_to_overlap_ratio_ *
|
||||
@ -1229,15 +1270,21 @@ void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
|
||||
ranges_.erase(copy_it);
|
||||
}
|
||||
|
||||
bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time,
|
||||
int64 end_time) const {
|
||||
absl::optional<AsynchronousCopy> AsynchronousCopyOrdering::ViolatesOrdering(
|
||||
int64 start_time, int64 end_time) const {
|
||||
// We allow identical start and end times. It is enough to check for just the
|
||||
// start time in case we find a match in ranges_ because the found value will
|
||||
// either be identical to {start_time, end_time} (and this doesn't violate) or
|
||||
// its start_time will be smaller and end_time will be larger (this violates).
|
||||
auto copy_it = ranges_.find(
|
||||
{start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
|
||||
return copy_it != ranges_.end() && copy_it->start_time != start_time;
|
||||
if (copy_it != ranges_.end() && copy_it->start_time != start_time) {
|
||||
VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time
|
||||
<< ") and (" << copy_it->start_time << ", " << copy_it->end_time
|
||||
<< ")";
|
||||
return *copy_it;
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
/*static*/ MemorySpaceAssignment::Allocation*
|
||||
@ -1734,8 +1781,9 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
|
||||
}
|
||||
}
|
||||
|
||||
bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
|
||||
int64 start_time, int64 end_time) const {
|
||||
absl::optional<AsynchronousCopy>
|
||||
AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64 start_time,
|
||||
int64 end_time) const {
|
||||
return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
|
||||
}
|
||||
|
||||
@ -1945,6 +1993,50 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
|
||||
return true;
|
||||
}
|
||||
|
||||
int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime(
|
||||
const AllocationRequest& request, int64 earliest_prefetch_time) const {
|
||||
int64 prefetch_end_time = request.latest_prefetch_time;
|
||||
|
||||
for (int retry_number = 0;
|
||||
retry_number < options_.prefetch_copy_done_reorder_max_retries;
|
||||
++retry_number) {
|
||||
int64 latest_prefetch_time =
|
||||
options_.prefetch_interval_picker->LatestPrefetchStartTime(
|
||||
request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
|
||||
// Return if we couldn't find a suitable prefetch start time.
|
||||
if (latest_prefetch_time < earliest_prefetch_time) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Return either if there is no other violating asynchronous copy (since we
|
||||
// don't need to change the prefetch end time) or if the violating
|
||||
// asynchronous copy ends after the prefetch end time.
|
||||
auto violating_async_copy =
|
||||
ViolatesAsyncCopyOrdering(latest_prefetch_time, prefetch_end_time);
|
||||
if (!violating_async_copy ||
|
||||
violating_async_copy->end_time >= prefetch_end_time) {
|
||||
break;
|
||||
}
|
||||
VLOG(4) << "Violating async copy: (" << violating_async_copy->start_time
|
||||
<< ", " << violating_async_copy->end_time << ")";
|
||||
|
||||
int64 new_prefetch_end_time =
|
||||
options_.prefetch_interval_picker->LatestPrefetchEndTime(
|
||||
prefetch_end_time, violating_async_copy->end_time);
|
||||
if (new_prefetch_end_time > earliest_prefetch_time) {
|
||||
VLOG(3) << "Update prefetch end time = " << new_prefetch_end_time;
|
||||
prefetch_end_time = new_prefetch_end_time;
|
||||
} else {
|
||||
VLOG(3) << "Can't update prefetch end time = " << new_prefetch_end_time
|
||||
<< " because earliest prefetch start time = "
|
||||
<< earliest_prefetch_time;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return prefetch_end_time;
|
||||
}
|
||||
|
||||
bool AlternateMemoryBestFitHeap::Prefetch(
|
||||
const AllocationRequest& request,
|
||||
const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
|
||||
@ -1966,9 +2058,11 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
||||
earliest_prefetch_time =
|
||||
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
|
||||
}
|
||||
options_.prefetch_interval_picker->Begin(request.use->hlo_use,
|
||||
earliest_prefetch_time,
|
||||
request.latest_prefetch_time);
|
||||
int64 prefetch_end_time =
|
||||
FindPrefetchEndTime(request, earliest_prefetch_time);
|
||||
|
||||
options_.prefetch_interval_picker->Begin(
|
||||
request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
|
||||
VLOG(3) << "Trying prefetch picker = "
|
||||
<< options_.prefetch_interval_picker->ToDebugString();
|
||||
|
||||
@ -1988,19 +2082,19 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
||||
: 0;
|
||||
while (!options_.prefetch_interval_picker->Done()) {
|
||||
alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
|
||||
CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
|
||||
CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
|
||||
VLOG(4) << "Trying alternate memory allocation ("
|
||||
<< alternate_mem_interval.start << ", " << request.end_time << ")";
|
||||
// If this additional asynchronous copy would violate the limit, try a
|
||||
// different interval.
|
||||
if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
|
||||
request.latest_prefetch_time)) {
|
||||
prefetch_end_time)) {
|
||||
VLOG(4) << "This would violate asynchronous copy ordering.";
|
||||
prefetch_failed_due_to_async_copy_ = true;
|
||||
continue;
|
||||
}
|
||||
if (ViolatesMaximumOutstandingAsyncCopies(
|
||||
alternate_mem_interval.start, request.latest_prefetch_time,
|
||||
alternate_mem_interval.start, prefetch_end_time,
|
||||
/*is_prefetch=*/true, extra_async_copy_limit)) {
|
||||
VLOG(4) << "This would violate the outstanding async copy limit.";
|
||||
prefetch_failed_due_to_async_copy_ = true;
|
||||
@ -2022,7 +2116,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
|
||||
|
||||
AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
|
||||
chunk_candidate->chunk, alternate_mem_interval.start,
|
||||
request.end_time, request.latest_prefetch_time,
|
||||
request.end_time, prefetch_end_time,
|
||||
request.allocation_value->allocation_sequence());
|
||||
|
||||
request.allocation_value->allocation_sequence()->back()->AddUse(
|
||||
|
@ -198,6 +198,17 @@ class PrefetchIntervalPicker {
|
||||
virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
||||
int64 latest_end_time) const = 0;
|
||||
|
||||
// Returns the latest time that a prefetch can start.
|
||||
virtual int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
|
||||
int64 end_time) const = 0;
|
||||
|
||||
// Returns the latest time that a prefetch can end that is less than or equal
|
||||
// to proposed_prefetch_end_time.
|
||||
virtual int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
|
||||
int64 proposed_prefetch_end_time) const {
|
||||
return proposed_prefetch_end_time;
|
||||
}
|
||||
|
||||
// Begins the iterator for the first start time of the prefetch.
|
||||
virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;
|
||||
|
||||
@ -256,6 +267,9 @@ class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
|
||||
int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
||||
int64 latest_end_time) const override;
|
||||
|
||||
int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
|
||||
int64 end_time) const override;
|
||||
|
||||
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
|
||||
|
||||
int64 Next() override;
|
||||
@ -292,6 +306,11 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
|
||||
int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
|
||||
int64 latest_end_time) const override;
|
||||
|
||||
int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
|
||||
int64 end_time) const override;
|
||||
int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
|
||||
int64 proposed_prefetch_end_time) const override;
|
||||
|
||||
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
|
||||
|
||||
int64 Next() override;
|
||||
@ -395,6 +414,11 @@ class MemorySpaceAssignment {
|
||||
// max_outstanding_prefetches).
|
||||
int64 while_use_extra_outstanding_prefetch_limit = 0;
|
||||
|
||||
// Specifies the maximum number of times we are willing to move a copy
|
||||
// done of a prefetch earlier due to an asynchronous copy ordering
|
||||
// violation.
|
||||
int64 prefetch_copy_done_reorder_max_retries = 1;
|
||||
|
||||
// Specifies the maximum number of retries that will be performed for each
|
||||
// value in case prefetching failed due to running out of asynchronous
|
||||
// copies or asynchronous copy ordering.
|
||||
@ -850,9 +874,9 @@ class AsynchronousCopyOrdering {
|
||||
// Removes an asynchronous copy. CHECKs that it is removed.
|
||||
void RemoveCopy(const AsynchronousCopy& copy);
|
||||
|
||||
// Returns true if the addition of an asynchronous copy in the the given time
|
||||
// interval would violate the asynchronous copy ordering. E.g., consider the
|
||||
// following scenario:
|
||||
// If the addition of an asynchronous copy in the given time interval would
|
||||
// violate the asynchronous copy ordering, returns the violating
|
||||
// already-committed asynchronous copy. E.g., consider the following scenario:
|
||||
// CS CD
|
||||
// already committed async copy: +-----------+
|
||||
// new async copy: +--------+
|
||||
@ -860,7 +884,8 @@ class AsynchronousCopyOrdering {
|
||||
// The new asynchronous copy would violate the ordering guarantee because the
|
||||
// copy start is after an already committed asynchronous copy while its copy
|
||||
// done is before the committed copy.
|
||||
bool ViolatesOrdering(int64 start_time, int64 end_time) const;
|
||||
absl::optional<AsynchronousCopy> ViolatesOrdering(int64 start_time,
|
||||
int64 end_time) const;
|
||||
|
||||
private:
|
||||
// Stores asynchronous copies in a tree set respecting the pipelining order.
|
||||
@ -981,6 +1006,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
// Try evicting to default memory space. Returns true if successful.
|
||||
bool Evict(const AllocationRequest& request);
|
||||
|
||||
// Returns the time a copy done of a prefetch should be scheduled.
|
||||
int64 FindPrefetchEndTime(const AllocationRequest& request,
|
||||
int64 earliest_prefetch_time) const;
|
||||
|
||||
// Try prefetching to alternate memory space. Returns true if successful.
|
||||
bool Prefetch(
|
||||
const AllocationRequest& request,
|
||||
@ -1045,8 +1074,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
int64 start_time, int64 end_time, bool is_prefetch,
|
||||
int64 extra_async_copy_limit = 0) const;
|
||||
|
||||
// Return true if the asynchronous copy would violate the pipelining order.
|
||||
bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
|
||||
// If the asynchronous copy would violate the pipelining order, returns the
|
||||
// violating asynchronous copy.
|
||||
absl::optional<AsynchronousCopy> ViolatesAsyncCopyOrdering(
|
||||
int64 start_time, int64 end_time) const;
|
||||
|
||||
// Adds an asynchronous copy to the allocations.
|
||||
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
|
||||
|
@ -286,6 +286,92 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
||||
MemorySpaceAssignmentCostAnalysis::Cache cache_;
|
||||
};
|
||||
|
||||
// For testing purposes, we define a cost analysis where we can control the
|
||||
// elapsed times of each HLO and asynchronous copy.
|
||||
class FakeMemorySpaceAssignmentCostAnalysis
|
||||
: public MemorySpaceAssignmentCostAnalysis {
|
||||
public:
|
||||
static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
|
||||
Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
|
||||
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
|
||||
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
|
||||
HloLiveRange::Run(module.schedule(), *alias_analysis,
|
||||
module.entry_computation()));
|
||||
auto call_graph = CallGraph::Build(&module);
|
||||
return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
|
||||
cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
|
||||
/*alternate_mem_bandwidth_bytes_per_second=*/1,
|
||||
std::move(alias_analysis), std::move(hlo_live_range),
|
||||
std::move(call_graph)));
|
||||
}
|
||||
|
||||
float GetInstructionElapsed(
|
||||
const HloInstruction& instruction) const override {
|
||||
if (get_instruction_elapsed_override_) {
|
||||
return get_instruction_elapsed_override_(instruction);
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
float GetInstructionElapsedInAlternateMemory(
|
||||
const HloInstruction& instruction,
|
||||
absl::optional<int64> operand_in_alternate_mem,
|
||||
bool output_in_alternate_mem) const override {
|
||||
if (get_instruction_elapsed_in_alternate_memory_override_) {
|
||||
return get_instruction_elapsed_in_alternate_memory_override_(
|
||||
instruction, operand_in_alternate_mem, output_in_alternate_mem);
|
||||
}
|
||||
if (operand_in_alternate_mem) {
|
||||
return 0.5;
|
||||
} else {
|
||||
return 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
float GetAsyncCopyElapsed(const Shape& shape) const override {
|
||||
if (get_async_copy_elapsed_override_) {
|
||||
return get_async_copy_elapsed_override_(shape);
|
||||
}
|
||||
return 3.0;
|
||||
}
|
||||
|
||||
// The following methods can be used to override what the above API calls
|
||||
// return.
|
||||
void SetOverrideForGetInstructionElapsed(
|
||||
std::function<float(const HloInstruction&)> function) {
|
||||
get_instruction_elapsed_override_ = function;
|
||||
}
|
||||
void SetOverrideForGetInstructionElapsedInAlternateMemory(
|
||||
std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
|
||||
function) {
|
||||
get_instruction_elapsed_in_alternate_memory_override_ = function;
|
||||
}
|
||||
void SetOverrideForGetAsyncCopyElapsed(
|
||||
std::function<float(const Shape&)> function) {
|
||||
get_async_copy_elapsed_override_ = function;
|
||||
}
|
||||
|
||||
protected:
|
||||
FakeMemorySpaceAssignmentCostAnalysis(
|
||||
const HloCostAnalysis& cost_analysis,
|
||||
float async_copy_bandwidth_bytes_per_second,
|
||||
float alternate_mem_bandwidth_bytes_per_second,
|
||||
std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
std::unique_ptr<HloLiveRange> hlo_live_range,
|
||||
std::unique_ptr<CallGraph> call_graph)
|
||||
: MemorySpaceAssignmentCostAnalysis(
|
||||
cost_analysis, async_copy_bandwidth_bytes_per_second,
|
||||
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
|
||||
std::move(hlo_live_range), std::move(call_graph)) {}
|
||||
|
||||
private:
|
||||
std::function<float(const HloInstruction&)>
|
||||
get_instruction_elapsed_override_ = nullptr;
|
||||
std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
|
||||
get_instruction_elapsed_in_alternate_memory_override_ = nullptr;
|
||||
std::function<float(const Shape&)> get_async_copy_elapsed_override_ = nullptr;
|
||||
};
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
|
||||
// A module consisting of a single parameter. Inputs/outputs are currently
|
||||
// excluded from memory space assignment.
|
||||
@ -3750,6 +3836,123 @@ TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
|
||||
buffer_interval_compare, &prefetch_interval_picker);
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) {
|
||||
// This tests the case where an earlier placed smaller buffer may block a
|
||||
// larger buffer due to asynchronous copy ordering. The smaller buffer (the
|
||||
// operand of sin) will be placed first. The cos, whose operand is 3 times
|
||||
// larger than sin's, needs longer time for the asynhronous copy. The cos is
|
||||
// placed right after sin, leading to a copy ordering violation:
|
||||
//
|
||||
// param1------------------>CS----->CD->sin
|
||||
// param0------------->CS------------------->CD->cos
|
||||
//
|
||||
// To fix this, we need to move copy done for cos earlier and ensure both of
|
||||
// these buffers get alternate memory allocations:
|
||||
//
|
||||
// param1------------------>CS----->CD->sin
|
||||
// param0-->CS------------------->CD------------>cos
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module, is_scheduled=true
|
||||
|
||||
ENTRY Entry {
|
||||
param0 = f32[8,3] parameter(0)
|
||||
param1 = f32[2,4] parameter(1)
|
||||
a = f32[2,4] negate(param1)
|
||||
b = f32[2,4] negate(a)
|
||||
c = f32[2,4] negate(b)
|
||||
d = f32[2,4] negate(c)
|
||||
e = f32[2,4] negate(d)
|
||||
f = f32[2,4] negate(e)
|
||||
g = f32[2,4] negate(f)
|
||||
h = f32[2,4] negate(g)
|
||||
i = f32[2,4] negate(h)
|
||||
j = f32[2,4] negate(i)
|
||||
k = f32[2,4] negate(j)
|
||||
l = f32[2,4] negate(k)
|
||||
m = f32[2,4] negate(l)
|
||||
n = f32[2,4] negate(m)
|
||||
sin = f32[2,4] sine(param1)
|
||||
o = f32[2,4] negate(n)
|
||||
cos = f32[8,3] cosine(param0)
|
||||
ROOT tuple = (f32[8,3], f32[2,4], f32[2,4]) tuple(cos, sin, o)
|
||||
}
|
||||
)";
|
||||
|
||||
MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
|
||||
[](const MemorySpaceAssignment::BufferInterval& a,
|
||||
const MemorySpaceAssignment::BufferInterval& b) {
|
||||
auto get_opcode_priority = [](const HloOpcode& opcode) {
|
||||
switch (opcode) {
|
||||
case HloOpcode::kSin:
|
||||
return 0;
|
||||
case HloOpcode::kCos:
|
||||
return 1;
|
||||
case HloOpcode::kTanh:
|
||||
return 2;
|
||||
default:
|
||||
return 3;
|
||||
}
|
||||
};
|
||||
|
||||
auto get_user_priority = [&](const HloValue& value) {
|
||||
int priority = INT_MAX;
|
||||
for (const auto& use : value.uses()) {
|
||||
priority = std::min(priority,
|
||||
get_opcode_priority(use.instruction->opcode()));
|
||||
}
|
||||
return priority;
|
||||
};
|
||||
|
||||
return get_user_priority(*a.buffer) < get_user_priority(*b.buffer);
|
||||
};
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
HloCostAnalysis hlo_cost_analysis(ShapeSize);
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
|
||||
FakeMemorySpaceAssignmentCostAnalysis::Create(
|
||||
hlo_cost_analysis, *module));
|
||||
cost_analysis->SetOverrideForGetAsyncCopyElapsed([](const Shape& shape) {
|
||||
// This should return 2 for f32[2,4] and 6 for f32[8,3].
|
||||
return ShapeSize(shape) / 16;
|
||||
});
|
||||
CostAnalysisPrefetchIntervalPicker interval_picker(
|
||||
*cost_analysis,
|
||||
/*min_async_copy_to_overlap_ratio=*/1.0,
|
||||
/*max_async_copy_to_overlap_ratio=*/4.0,
|
||||
/*preferred_async_copy_to_overlap_ratio=*/1.5);
|
||||
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
|
||||
buffer_interval_compare, &interval_picker);
|
||||
|
||||
// Check that both cos and sin could get their operands prefetched.
|
||||
const HloInstruction* cos =
|
||||
module->entry_computation()->GetInstructionWithName("cos");
|
||||
const HloInstruction* sin =
|
||||
module->entry_computation()->GetInstructionWithName("sin");
|
||||
EXPECT_THAT(sin->operand(0),
|
||||
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
|
||||
op::Parameter(1)));
|
||||
EXPECT_THAT(cos->operand(0),
|
||||
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
|
||||
op::Parameter(0)));
|
||||
|
||||
// Sanity check that the cos' operand copy-done is scheduled earlier than
|
||||
// sin's operand.
|
||||
auto find_schedule_index = [&](const HloInstruction* instruction) {
|
||||
const auto& instructions =
|
||||
module->schedule().sequence(module->entry_computation()).instructions();
|
||||
for (int i = 0; i < instructions.size(); ++i) {
|
||||
if (instruction == instructions[i]) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
CHECK(false);
|
||||
return -1;
|
||||
};
|
||||
EXPECT_GT(find_schedule_index(sin->operand(0)),
|
||||
find_schedule_index(cos->operand(0)));
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, Determinism) {
|
||||
// Run memory space assignment a few times to make sure every time it compiles
|
||||
// to the same thing.
|
||||
@ -4046,57 +4249,6 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
|
||||
EXPECT_EQ(cross_program_prefetches.size(), 0);
|
||||
}
|
||||
|
||||
// For testing purposes, we define a cost analysis where we can control the
|
||||
// elapsed times of each HLO and asynchronous copy.
|
||||
class FakeMemorySpaceAssignmentCostAnalysis
|
||||
: public MemorySpaceAssignmentCostAnalysis {
|
||||
public:
|
||||
static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
|
||||
Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
|
||||
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
|
||||
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
|
||||
HloLiveRange::Run(module.schedule(), *alias_analysis,
|
||||
module.entry_computation()));
|
||||
auto call_graph = CallGraph::Build(&module);
|
||||
return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
|
||||
cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
|
||||
/*alternate_mem_bandwidth_bytes_per_second=*/1,
|
||||
std::move(alias_analysis), std::move(hlo_live_range),
|
||||
std::move(call_graph)));
|
||||
}
|
||||
|
||||
float GetInstructionElapsed(
|
||||
const HloInstruction& instruction) const override {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
float GetInstructionElapsedInAlternateMemory(
|
||||
const HloInstruction& instruction,
|
||||
absl::optional<int64> operand_in_alternate_mem,
|
||||
bool output_in_alternate_mem) const override {
|
||||
if (operand_in_alternate_mem) {
|
||||
return 0.5;
|
||||
} else {
|
||||
return 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
float GetAsyncCopyElapsed(const Shape& shape) const override { return 3.0; }
|
||||
|
||||
protected:
|
||||
FakeMemorySpaceAssignmentCostAnalysis(
|
||||
const HloCostAnalysis& cost_analysis,
|
||||
float async_copy_bandwidth_bytes_per_second,
|
||||
float alternate_mem_bandwidth_bytes_per_second,
|
||||
std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
std::unique_ptr<HloLiveRange> hlo_live_range,
|
||||
std::unique_ptr<CallGraph> call_graph)
|
||||
: MemorySpaceAssignmentCostAnalysis(
|
||||
cost_analysis, async_copy_bandwidth_bytes_per_second,
|
||||
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
|
||||
std::move(hlo_live_range), std::move(call_graph)) {}
|
||||
};
|
||||
|
||||
using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
|
||||
|
||||
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
|
||||
|
Loading…
Reference in New Issue
Block a user