[XLA] Allow copy-done to be scheduled earlier to avoid copy ordering issues.

We previously required copy-dones to be scheduled right before their uses.
We also require asynchronous copies to maintain pipelining order (no nested
copy-start/copy-done pairs). Together, these constraints could mean that a
smaller buffer with a shorter copy-start/copy-done duration blocks a larger
buffer due to copy ordering. E.g., this disallowed situation might arise:

  small tensor in default mem---------------------->CS----->CD->use
  large tensor in default mem------------------>CS----------------->CD->use
====================================================================> time

This CL checks whether an already-committed asynchronous copy would violate
the pipelining order. If so, we attempt to move the new prefetch's copy-done
earlier:

  small tensor in default mem---------------------->CS----->CD->use
  large tensor in default mem-------->CS----------------->CD----------->use
====================================================================> time
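
For illustration, here is a minimal, self-contained sketch of the nesting
check (hypothetical names, loosely modeled on the AsynchronousCopyOrdering
changes below; this is not the exact XLA code). Committed copies live in a
set whose comparator encodes pipelining order, so a candidate interval that
would nest with a committed copy compares equivalent to it, and a single
find() surfaces the violating copy:

  // Sketch only: AsyncCopy, PipeliningOrder, and CopyOrdering are assumed
  // names for this example, not the real XLA types.
  #include <cstdint>
  #include <iostream>
  #include <optional>
  #include <set>

  struct AsyncCopy {
    int64_t start_time;
    int64_t end_time;
  };

  // Pipelining order: a precedes b if a starts no later AND ends no later
  // (at least one strictly). Copies where one strictly nests inside the
  // other are incomparable, i.e. "equivalent" to std::set, so find() on a
  // candidate interval returns the committed copy it would nest with. This
  // relies on the invariant that nested copies are never actually inserted.
  struct PipeliningOrder {
    bool operator()(const AsyncCopy& a, const AsyncCopy& b) const {
      return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
             (a.start_time <= b.start_time && a.end_time < b.end_time);
    }
  };

  class CopyOrdering {
   public:
    void AddCopy(const AsyncCopy& copy) { ranges_.insert(copy); }

    // Returns the committed copy that a new copy over [start, end] would
    // nest with, if any. An exact match (same start time) does not violate.
    std::optional<AsyncCopy> ViolatesOrdering(int64_t start,
                                              int64_t end) const {
      auto it = ranges_.find({start, end});
      if (it != ranges_.end() && it->start_time != start) return *it;
      return std::nullopt;
    }

   private:
    std::set<AsyncCopy, PipeliningOrder> ranges_;
  };

  int main() {
    CopyOrdering ordering;
    ordering.AddCopy({/*start_time=*/5, /*end_time=*/20});
    // Nested: starts after 5 but would finish before 20 -> violation.
    if (auto v = ordering.ViolatesOrdering(10, 15)) {
      std::cout << "violates (" << v->start_time << ", " << v->end_time
                << ")\n";  // prints: violates (5, 20)
    }
    // Starts later and ends later -> pipelined, no violation.
    std::cout << ordering.ViolatesOrdering(10, 25).has_value() << "\n";  // 0
  }

When this check fires, FindPrefetchEndTime (first file below) retries with
the violating copy's end time as the new, earlier position for the
prefetch's copy-done.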

PiperOrigin-RevId: 322621813
Change-Id: I8287c11c96a6d71de86a5f2e22cf1846c26ef4f3
Berkin Ilbeyi authored 2020-07-22 11:50:46 -07:00; committed by TensorFlower Gardener
parent 30885b432e
commit 70edbdb6c7
3 changed files with 354 additions and 77 deletions

tensorflow/compiler/xla/service/memory_space_assignment.cc

@@ -235,6 +235,11 @@ int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
return std::min(start_time + min_overlap_count_, latest_end_time);
}
int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
const HloUse& use, int64 start_time, int64 end_time) const {
return end_time - min_overlap_count_;
}
void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
int64 start_time,
int64 end_time) {
@@ -355,6 +360,49 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
return end_time;
}
int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
const HloUse& use, int64 start_time, int64 end_time) const {
const Shape& shape = ShapeUtil::GetSubshape(
use.instruction->operand(use.operand_number)->shape(), use.operand_index);
// Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
// Estimate the time we would save by having this op in alternate memory.
float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
float elapsed_time_in_alternate_mem =
cost_analysis_.GetInstructionElapsedInAlternateMemory(
*use.instruction, use.operand_number,
/*output_in_alternate_mem=*/false);
float inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
int end_nest_level = while_nest_level_[end_time];
// Find the latest time we're allowed to start prefetching.
float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed;
int latest_prefetch_time;
for (latest_prefetch_time = end_time - 1;
latest_prefetch_time >= start_time &&
(while_nest_level_[latest_prefetch_time] != end_nest_level ||
min_interval >
GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
inst_elapsed_reduction);
--latest_prefetch_time) {
}
return latest_prefetch_time;
}
int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const {
// Iterate towards the beginning until we find a suitable end time that is at
// the same while nest level as the original prefetch end time.
int64 original_nest_level = while_nest_level_[original_prefetch_end_time];
int64 new_prefetch_end_time;
for (new_prefetch_end_time = proposed_prefetch_end_time;
while_nest_level_[new_prefetch_end_time] != original_nest_level;
--new_prefetch_end_time) {
}
return new_prefetch_end_time;
}
void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
int64 start_time,
int64 end_time) {
@@ -374,14 +422,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
// Find the latest time we're allowed to start prefetching.
float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
for (latest_prefetch_time_ = end_logical_time_ - 1;
latest_prefetch_time_ >= start_time &&
(while_nest_level_[latest_prefetch_time_] != end_nest_level ||
min_interval > GetLogicalIntervalElapsed(latest_prefetch_time_,
end_logical_time_) +
inst_elapsed_reduction_);
--latest_prefetch_time_) {
}
latest_prefetch_time_ = LatestPrefetchStartTime(use, start_time, end_time);
// Find the earliest time we're allowed to start prefetching.
float max_interval = max_async_copy_to_overlap_ratio_ *
@@ -1229,15 +1270,21 @@ void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
ranges_.erase(copy_it);
}
bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time,
int64 end_time) const {
absl::optional<AsynchronousCopy> AsynchronousCopyOrdering::ViolatesOrdering(
int64 start_time, int64 end_time) const {
// We allow identical start and end times. It is enough to check just the
// start time when we find a match in ranges_ because the found value will
// either be identical to {start_time, end_time} (which doesn't violate) or
// its start_time will be smaller and its end_time will be larger (which
// violates).
auto copy_it = ranges_.find(
{start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
return copy_it != ranges_.end() && copy_it->start_time != start_time;
if (copy_it != ranges_.end() && copy_it->start_time != start_time) {
VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time
<< ") and (" << copy_it->start_time << ", " << copy_it->end_time
<< ")";
return *copy_it;
}
return absl::nullopt;
}
/*static*/ MemorySpaceAssignment::Allocation*
@@ -1734,8 +1781,9 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
}
}
bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
int64 start_time, int64 end_time) const {
absl::optional<AsynchronousCopy>
AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64 start_time,
int64 end_time) const {
return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
}
@@ -1945,6 +1993,50 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
return true;
}
int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime(
const AllocationRequest& request, int64 earliest_prefetch_time) const {
int64 prefetch_end_time = request.latest_prefetch_time;
for (int retry_number = 0;
retry_number < options_.prefetch_copy_done_reorder_max_retries;
++retry_number) {
int64 latest_prefetch_time =
options_.prefetch_interval_picker->LatestPrefetchStartTime(
request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
// Return if we couldn't find a suitable prefetch start time.
if (latest_prefetch_time < earliest_prefetch_time) {
break;
}
// Return either if there is no violating asynchronous copy (since we don't
// need to change the prefetch end time) or if the violating asynchronous
// copy ends at or after the prefetch end time.
auto violating_async_copy =
ViolatesAsyncCopyOrdering(latest_prefetch_time, prefetch_end_time);
if (!violating_async_copy ||
violating_async_copy->end_time >= prefetch_end_time) {
break;
}
VLOG(4) << "Violating async copy: (" << violating_async_copy->start_time
<< ", " << violating_async_copy->end_time << ")";
int64 new_prefetch_end_time =
options_.prefetch_interval_picker->LatestPrefetchEndTime(
prefetch_end_time, violating_async_copy->end_time);
if (new_prefetch_end_time > earliest_prefetch_time) {
VLOG(3) << "Update prefetch end time = " << new_prefetch_end_time;
prefetch_end_time = new_prefetch_end_time;
} else {
VLOG(3) << "Can't update prefetch end time = " << new_prefetch_end_time
<< " because earliest prefetch start time = "
<< earliest_prefetch_time;
break;
}
}
return prefetch_end_time;
}
bool AlternateMemoryBestFitHeap::Prefetch(
const AllocationRequest& request,
const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
@@ -1966,9 +2058,11 @@ bool AlternateMemoryBestFitHeap::Prefetch(
earliest_prefetch_time =
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
}
options_.prefetch_interval_picker->Begin(request.use->hlo_use,
earliest_prefetch_time,
request.latest_prefetch_time);
int64 prefetch_end_time =
FindPrefetchEndTime(request, earliest_prefetch_time);
options_.prefetch_interval_picker->Begin(
request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
VLOG(3) << "Trying prefetch picker = "
<< options_.prefetch_interval_picker->ToDebugString();
@@ -1988,19 +2082,19 @@ bool AlternateMemoryBestFitHeap::Prefetch(
: 0;
while (!options_.prefetch_interval_picker->Done()) {
alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
VLOG(4) << "Trying alternate memory allocation ("
<< alternate_mem_interval.start << ", " << request.end_time << ")";
// If this additional asynchronous copy would violate the limit, try a
// different interval.
if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
request.latest_prefetch_time)) {
prefetch_end_time)) {
VLOG(4) << "This would violate asynchronous copy ordering.";
prefetch_failed_due_to_async_copy_ = true;
continue;
}
if (ViolatesMaximumOutstandingAsyncCopies(
alternate_mem_interval.start, request.latest_prefetch_time,
alternate_mem_interval.start, prefetch_end_time,
/*is_prefetch=*/true, extra_async_copy_limit)) {
VLOG(4) << "This would violate the outstanding async copy limit.";
prefetch_failed_due_to_async_copy_ = true;
@@ -2022,7 +2116,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
chunk_candidate->chunk, alternate_mem_interval.start,
request.end_time, request.latest_prefetch_time,
request.end_time, prefetch_end_time,
request.allocation_value->allocation_sequence());
request.allocation_value->allocation_sequence()->back()->AddUse(

tensorflow/compiler/xla/service/memory_space_assignment.h

@@ -198,6 +198,17 @@ class PrefetchIntervalPicker {
virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
int64 latest_end_time) const = 0;
// Returns the latest time that a prefetch can start.
virtual int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
int64 end_time) const = 0;
// Returns the latest time that a prefetch can end that is less than or equal
// to proposed_prefetch_end_time.
virtual int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
int64 proposed_prefetch_end_time) const {
return proposed_prefetch_end_time;
}
// Begins the iterator for the first start time of the prefetch.
virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;
@@ -256,6 +267,9 @@ class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
int64 latest_end_time) const override;
int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
int64 end_time) const override;
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
int64 Next() override;
@@ -292,6 +306,11 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
int64 latest_end_time) const override;
int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
int64 end_time) const override;
int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
int64 proposed_prefetch_end_time) const override;
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
int64 Next() override;
@@ -395,6 +414,11 @@ class MemorySpaceAssignment {
// max_outstanding_prefetches).
int64 while_use_extra_outstanding_prefetch_limit = 0;
// Specifies the maximum number of times we are willing to move a
// copy-done of a prefetch earlier due to an asynchronous copy ordering
// violation.
int64 prefetch_copy_done_reorder_max_retries = 1;
// Specifies the maximum number of retries that will be performed for each
// value in case prefetching failed due to running out of asynchronous
// copies or asynchronous copy ordering.
@@ -850,9 +874,9 @@ class AsynchronousCopyOrdering {
// Removes an asynchronous copy. CHECKs that it is removed.
void RemoveCopy(const AsynchronousCopy& copy);
// Returns true if the addition of an asynchronous copy in the the given time
// interval would violate the asynchronous copy ordering. E.g., consider the
// following scenario:
// If the addition of an asynchronous copy in the given time interval would
// violate the asynchronous copy ordering, returns the violating
// already-committed asynchronous copy. E.g., consider the following scenario:
// CS CD
// already committed async copy: +-----------+
// new async copy: +--------+
@@ -860,7 +884,8 @@
// The new asynchronous copy would violate the ordering guarantee because the
// copy start is after an already committed asynchronous copy while its copy
// done is before the committed copy.
bool ViolatesOrdering(int64 start_time, int64 end_time) const;
absl::optional<AsynchronousCopy> ViolatesOrdering(int64 start_time,
int64 end_time) const;
private:
// Stores asynchronous copies in a tree set respecting the pipelining order.
@@ -981,6 +1006,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// Try evicting to default memory space. Returns true if successful.
bool Evict(const AllocationRequest& request);
// Returns the time a copy done of a prefetch should be scheduled.
int64 FindPrefetchEndTime(const AllocationRequest& request,
int64 earliest_prefetch_time) const;
// Try prefetching to alternate memory space. Returns true if successful.
bool Prefetch(
const AllocationRequest& request,
@@ -1045,8 +1074,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
int64 start_time, int64 end_time, bool is_prefetch,
int64 extra_async_copy_limit = 0) const;
// Return true if the asynchronous copy would violate the pipelining order.
bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
// If the asynchronous copy would violate the pipelining order, returns the
// violating asynchronous copy.
absl::optional<AsynchronousCopy> ViolatesAsyncCopyOrdering(
int64 start_time, int64 end_time) const;
// Adds an asynchronous copy to the allocations.
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,

tensorflow/compiler/xla/service/memory_space_assignment_test.cc

@@ -286,6 +286,92 @@ class MemorySpaceAssignmentTest : public HloTestBase,
MemorySpaceAssignmentCostAnalysis::Cache cache_;
};
// For testing purposes, we define a cost analysis where we can control the
// elapsed times of each HLO and asynchronous copy.
class FakeMemorySpaceAssignmentCostAnalysis
: public MemorySpaceAssignmentCostAnalysis {
public:
static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
HloLiveRange::Run(module.schedule(), *alias_analysis,
module.entry_computation()));
auto call_graph = CallGraph::Build(&module);
return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
/*alternate_mem_bandwidth_bytes_per_second=*/1,
std::move(alias_analysis), std::move(hlo_live_range),
std::move(call_graph)));
}
float GetInstructionElapsed(
const HloInstruction& instruction) const override {
if (get_instruction_elapsed_override_) {
return get_instruction_elapsed_override_(instruction);
}
return 1.0;
}
float GetInstructionElapsedInAlternateMemory(
const HloInstruction& instruction,
absl::optional<int64> operand_in_alternate_mem,
bool output_in_alternate_mem) const override {
if (get_instruction_elapsed_in_alternate_memory_override_) {
return get_instruction_elapsed_in_alternate_memory_override_(
instruction, operand_in_alternate_mem, output_in_alternate_mem);
}
if (operand_in_alternate_mem) {
return 0.5;
} else {
return 1.0;
}
}
float GetAsyncCopyElapsed(const Shape& shape) const override {
if (get_async_copy_elapsed_override_) {
return get_async_copy_elapsed_override_(shape);
}
return 3.0;
}
// The following methods can be used to override what the above API calls
// return.
void SetOverrideForGetInstructionElapsed(
std::function<float(const HloInstruction&)> function) {
get_instruction_elapsed_override_ = function;
}
void SetOverrideForGetInstructionElapsedInAlternateMemory(
std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
function) {
get_instruction_elapsed_in_alternate_memory_override_ = function;
}
void SetOverrideForGetAsyncCopyElapsed(
std::function<float(const Shape&)> function) {
get_async_copy_elapsed_override_ = function;
}
protected:
FakeMemorySpaceAssignmentCostAnalysis(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second,
std::unique_ptr<HloAliasAnalysis> alias_analysis,
std::unique_ptr<HloLiveRange> hlo_live_range,
std::unique_ptr<CallGraph> call_graph)
: MemorySpaceAssignmentCostAnalysis(
cost_analysis, async_copy_bandwidth_bytes_per_second,
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
std::move(hlo_live_range), std::move(call_graph)) {}
private:
std::function<float(const HloInstruction&)>
get_instruction_elapsed_override_ = nullptr;
std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
get_instruction_elapsed_in_alternate_memory_override_ = nullptr;
std::function<float(const Shape&)> get_async_copy_elapsed_override_ = nullptr;
};
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
// A module consisting of a single parameter. Inputs/outputs are currently
// excluded from memory space assignment.
@@ -3750,6 +3836,123 @@ TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
buffer_interval_compare, &prefetch_interval_picker);
}
TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) {
// This tests the case where an earlier placed smaller buffer may block a
// larger buffer due to asynchronous copy ordering. The smaller buffer (the
// operand of sin) will be placed first. The cos, whose operand is 3 times
// larger than sin's, needs more time for the asynchronous copy. The cos is
// placed right after sin, leading to a copy ordering violation:
//
// param1------------------>CS----->CD->sin
// param0------------->CS------------------->CD->cos
//
// To fix this, we need to move the copy-done for cos earlier and ensure both
// of these buffers get alternate memory allocations:
//
// param1------------------>CS----->CD->sin
// param0-->CS------------------->CD------------>cos
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true
ENTRY Entry {
param0 = f32[8,3] parameter(0)
param1 = f32[2,4] parameter(1)
a = f32[2,4] negate(param1)
b = f32[2,4] negate(a)
c = f32[2,4] negate(b)
d = f32[2,4] negate(c)
e = f32[2,4] negate(d)
f = f32[2,4] negate(e)
g = f32[2,4] negate(f)
h = f32[2,4] negate(g)
i = f32[2,4] negate(h)
j = f32[2,4] negate(i)
k = f32[2,4] negate(j)
l = f32[2,4] negate(k)
m = f32[2,4] negate(l)
n = f32[2,4] negate(m)
sin = f32[2,4] sine(param1)
o = f32[2,4] negate(n)
cos = f32[8,3] cosine(param0)
ROOT tuple = (f32[8,3], f32[2,4], f32[2,4]) tuple(cos, sin, o)
}
)";
MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
[](const MemorySpaceAssignment::BufferInterval& a,
const MemorySpaceAssignment::BufferInterval& b) {
auto get_opcode_priority = [](const HloOpcode& opcode) {
switch (opcode) {
case HloOpcode::kSin:
return 0;
case HloOpcode::kCos:
return 1;
case HloOpcode::kTanh:
return 2;
default:
return 3;
}
};
auto get_user_priority = [&](const HloValue& value) {
int priority = INT_MAX;
for (const auto& use : value.uses()) {
priority = std::min(priority,
get_opcode_priority(use.instruction->opcode()));
}
return priority;
};
return get_user_priority(*a.buffer) < get_user_priority(*b.buffer);
};
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
HloCostAnalysis hlo_cost_analysis(ShapeSize);
TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
FakeMemorySpaceAssignmentCostAnalysis::Create(
hlo_cost_analysis, *module));
cost_analysis->SetOverrideForGetAsyncCopyElapsed([](const Shape& shape) {
// This should return 2 for f32[2,4] and 6 for f32[8,3].
return ShapeSize(shape) / 16;
});
CostAnalysisPrefetchIntervalPicker interval_picker(
*cost_analysis,
/*min_async_copy_to_overlap_ratio=*/1.0,
/*max_async_copy_to_overlap_ratio=*/4.0,
/*preferred_async_copy_to_overlap_ratio=*/1.5);
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
buffer_interval_compare, &interval_picker);
// Check that both cos and sin could get their operands prefetched.
const HloInstruction* cos =
module->entry_computation()->GetInstructionWithName("cos");
const HloInstruction* sin =
module->entry_computation()->GetInstructionWithName("sin");
EXPECT_THAT(sin->operand(0),
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
op::Parameter(1)));
EXPECT_THAT(cos->operand(0),
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
op::Parameter(0)));
// Sanity check that cos's operand copy-done is scheduled earlier than sin's
// operand's.
auto find_schedule_index = [&](const HloInstruction* instruction) {
const auto& instructions =
module->schedule().sequence(module->entry_computation()).instructions();
for (int i = 0; i < instructions.size(); ++i) {
if (instruction == instructions[i]) {
return i;
}
}
CHECK(false);
return -1;
};
EXPECT_GT(find_schedule_index(sin->operand(0)),
find_schedule_index(cos->operand(0)));
}
TEST_P(MemorySpaceAssignmentTest, Determinism) {
// Run memory space assignment a few times to make sure every time it compiles
// to the same thing.
@@ -4046,57 +4249,6 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
EXPECT_EQ(cross_program_prefetches.size(), 0);
}
// For testing purposes, we define a cost analysis where we can control the
// elapsed times of each HLO and asynchronous copy.
class FakeMemorySpaceAssignmentCostAnalysis
: public MemorySpaceAssignmentCostAnalysis {
public:
static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
HloLiveRange::Run(module.schedule(), *alias_analysis,
module.entry_computation()));
auto call_graph = CallGraph::Build(&module);
return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
/*alternate_mem_bandwidth_bytes_per_second=*/1,
std::move(alias_analysis), std::move(hlo_live_range),
std::move(call_graph)));
}
float GetInstructionElapsed(
const HloInstruction& instruction) const override {
return 1.0;
}
float GetInstructionElapsedInAlternateMemory(
const HloInstruction& instruction,
absl::optional<int64> operand_in_alternate_mem,
bool output_in_alternate_mem) const override {
if (operand_in_alternate_mem) {
return 0.5;
} else {
return 1.0;
}
}
float GetAsyncCopyElapsed(const Shape& shape) const override { return 3.0; }
protected:
FakeMemorySpaceAssignmentCostAnalysis(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second,
std::unique_ptr<HloAliasAnalysis> alias_analysis,
std::unique_ptr<HloLiveRange> hlo_live_range,
std::unique_ptr<CallGraph> call_graph)
: MemorySpaceAssignmentCostAnalysis(
cost_analysis, async_copy_bandwidth_bytes_per_second,
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
std::move(hlo_live_range), std::move(call_graph)) {}
};
using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {