[XLA] Fix alternate memory allocation of conditional operands.

Consider the following flattened HLO schedule of a conditional:

1: a = fusion()
   true_computation:
2:    parameter = parameter(0)
3:    ...
4:    ...
   false_computation:
5:    parameter = parameter(0)
6:    ...
7:    ...
8: conditional = conditional(pred, a, a)
9: b = fusion(a)

When we had a tensor that was a conditional operand (e.g. "a" in the example),
we reserved the alternate memory for the entire 1-8 range. This meant that when
we tried to allocate inside the called computations of the conditional, the
offset we picked wasn't available since it would fall within the 1-8 range. This
CL now reserves the conditional until the parameter of the earliest called
computations (1-2 range).

To allow efficient use of alternate memory by avoiding a very large conditional
from claiming the offset for the entire called computation, the conditional
operand might die within the called computation, allowing other HLOs inside the
called computations to reclaim that alternate memory offset. This creates a
subtlety for subsequent uses of conditional operands (e.g. "a" is used by a
fusion at 9). These subsequent uses will force evictions (and then do another
prefetch). After optimization, the graph might look like the following:

  a (Alternate Mem) = fusion()
  cs0 = copy-start(a)  # Must evict a because the allocation may die within
                       # called computation.
  cd0 (Default Mem) = copy-done(cs0)
  true_computation:
    parameter (Alternate Mem) = parameter(0)
    ...
    # parameter's alternate memory allocation may die here and another tensor
    # might use the same offset.
  false_computation:
    parameter (Alternate Mem) = parameter(0)
    ...
    # parameter's alternate memory allocation may die here and another tensor
    # might use the same offset.
  conditional = conditional(pred, a, a)
  cs1 = copy-start(cd0)  # May prefetch the value back to alternate memory.
  cd1 (Alternate Mem) = copy-done(cs1)
  b = fusion(cd1)

PiperOrigin-RevId: 312182824
Change-Id: I3ff5d019025ef96ced1aed4f6d170df677273348
This commit is contained in:
Berkin Ilbeyi 2020-05-18 17:01:57 -07:00 committed by TensorFlower Gardener
parent ad6e816328
commit ad6798a2f6
3 changed files with 563 additions and 72 deletions

View File

@ -502,7 +502,8 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
}
bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
const HloUse& use) const {
const AllocationValue& value, const HloUse& use) const {
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
if (use.instruction->opcode() == HloOpcode::kWhile) {
HloComputation* while_body = use.instruction->while_body();
@ -512,7 +513,6 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
HloValue* parameter_value =
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
while_body->parameter_instruction(0), use.operand_index);
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
int64 parameter_time =
instruction_schedule.at(while_body->parameter_instruction(0));
int64 root_time = instruction_schedule.at(while_body->root_instruction());
@ -567,7 +567,54 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
"there is a required default memory assignment.";
return false;
}
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
// For any use of this conditional (the same value might be passed into
// multiple called computations), determine if the parameter->first use
// dependency is short.
int64 conditional_time = instruction_schedule.at(use.instruction);
for (const HloUse& other_use : value.uses()) {
if (other_use.instruction != use.instruction) {
continue;
}
HloComputation* called_computation =
use.instruction->called_computations().at(other_use.operand_number -
1);
const HloInstruction* parameter_instruction =
called_computation->parameter_instruction(0);
HloValue* parameter_value =
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
parameter_instruction, other_use.operand_index);
int64 parameter_time = instruction_schedule.at(parameter_instruction);
int64 min_use_time = conditional_time;
for (const HloUse& parameter_use : parameter_value->uses()) {
if (parameter_use.instruction->parent() == called_computation &&
parameter_use.instruction->opcode() !=
HloOpcode::kGetTupleElement &&
parameter_use.instruction->opcode() != HloOpcode::kTuple &&
parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
min_use_time = std::min(
min_use_time, instruction_schedule.at(parameter_use.instruction));
}
}
if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
parameter_value->shape(), parameter_time, min_use_time)) {
VLOG(4) << "Conditional allocation allowed in alternate memory for "
"computation = "
<< called_computation->name()
<< ", parameter time = " << parameter_time
<< ", min use time = " << min_use_time;
return true;
} else {
VLOG(4) << "Conditional allocation not allowed in alternate memory for "
"computation = "
<< called_computation->name()
<< ", parameter time = " << parameter_time
<< ", min use time = " << min_use_time;
}
}
return false;
}
return true;
}
@ -769,20 +816,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
if (position.instruction->opcode() == HloOpcode::kConditional) {
VLOG(3) << "Adding required assignment for condition output: "
<< value->ToShortString();
required_assignments_[value].push_back(
{MemorySpace::kDefault,
instruction_schedule.at(position.instruction),
/*chunk=*/absl::nullopt});
AddRequiredAssignment(position.instruction, position.index,
MemorySpace::kDefault);
for (const HloComputation* called_computation :
position.instruction->called_computations()) {
HloValue* root_value =
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(
called_computation->root_instruction(), position.index);
required_assignments_[root_value].push_back(
{MemorySpace::kDefault,
instruction_schedule.at(
called_computation->root_instruction()),
/*chunk=*/absl::nullopt});
AddRequiredAssignment(called_computation->root_instruction(),
position.index, MemorySpace::kDefault);
}
}
}
@ -808,9 +847,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
}
// Iterate over the uses.
for (HloUse use : allocation_value.uses()) {
for (int use_idx = 0; use_idx < allocation_value.uses().size();
++use_idx) {
const HloUse& use = allocation_value.uses().at(use_idx);
int64 use_time = instruction_schedule.at(use.instruction);
int64 latest_prefetch_time = use_time;
bool allow_no_copy_alternate_mem_allocation = true;
absl::optional<int64> earliest_prefetch_time = absl::nullopt;
// Sequential calls include kWhile, kCall, and kConditional opcodes.
bool is_sequential_call =
@ -857,14 +900,41 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
// when we look at uses within the while loop body.
use_time =
instruction_schedule.at(while_body->parameter_instruction(0));
} else if (use.instruction->opcode() == HloOpcode::kConditional) {
// Replace the use time with the earliest parameter of called
// computations.
for (const HloComputation* called_computation :
use.instruction->called_computations()) {
use_time = std::min(
use_time, instruction_schedule.at(
called_computation->parameter_instruction(0)));
}
}
}
// Add a required assignment in default memory if the use not allowed in
// alternate memory.
if (!IsUseAllowedInAlternateMemory(use)) {
required_assignments_[allocation_value.value()].push_back(
{MemorySpace::kDefault, use_time, /*chunk=*/absl::nullopt});
if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
AddRequiredAssignment(allocation_value.value(), use.instruction,
MemorySpace::kDefault, use_time);
} else if (use_idx > 0) {
// We allow buffers in alternate memory that are passed into
// conditionals to give up their alternate memory allocation inside
// the called computation. This means that if a conditional operator
// has an alternate memory allocation, subsequent uses cannot use the
// same alternate memory allocation in order not to clobber data. So
// we force default memory allocation for these subsequent uses.
const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
previous_use.instruction != use.instruction) {
allow_no_copy_alternate_mem_allocation = false;
earliest_prefetch_time =
instruction_schedule.at(previous_use.instruction);
VLOG(3) << "Previous use (" << previous_use.ToString()
<< ") of use (" << use.ToString()
<< ") is a conditional, so this use will need to evict. "
<< "Earliest prefetch time = " << *earliest_prefetch_time;
}
}
// Bitcasts don't define buffers and don't directly consume buffers.
@ -872,10 +942,16 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
// bitcasts will be handled specially.
if (use.instruction->opcode() != HloOpcode::kBitcast) {
AllocationRequest request;
request.start_time = definition_time;
// Rarely, (e.g., when conditional true and false parameters are the
// same), definition time can be the time of the conditional and use
// time is the parameter use, which is less.
request.start_time = std::min(definition_time, use_time);
request.end_time = use_time;
request.latest_prefetch_time = latest_prefetch_time;
request.size = interval.size;
request.allow_no_copy_alternate_mem_allocation =
allow_no_copy_alternate_mem_allocation;
request.earliest_prefetch_time = earliest_prefetch_time;
request.preferred_offset = preferred_offset;
request.use = use;
request.allocation_value = &allocation_value;
@ -1061,35 +1137,42 @@ void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
if (aliased_allocation->memory_space() == MemorySpace::kAlternate) {
chunk = aliased_allocation->chunk();
}
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
HloValue* value =
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index);
int64 instruction_time = instruction_schedule.at(instruction);
AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(),
chunk);
}
void AlternateMemoryBestFitHeap::AddRequiredAssignment(
const HloValue* value, const HloInstruction* instruction,
MemorySpaceAssignment::MemorySpace memory_space, int64 time,
absl::optional<HeapSimulator::Chunk> chunk) {
// Check for existing required assignment at this time and make sure it is the
// same as this if there is one.
auto existing_required_assignment =
RequiredMemoryAssignmentAt(value, instruction_time);
auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time);
if (existing_required_assignment) {
CHECK(aliased_allocation->memory_space() ==
existing_required_assignment->memory_space);
CHECK(memory_space == existing_required_assignment->memory_space)
<< "inst = " << instruction->ToString() << " at " << time;
CHECK((!chunk && !existing_required_assignment->chunk) ||
chunk->offset == existing_required_assignment->chunk->offset);
VLOG(3) << "Not adding aliased required assignment because there is one "
"already: "
<< value->ToShortString() << " at " << instruction_time << " at "
<< (aliased_allocation->memory_space() == MemorySpace::kDefault
? "def"
: "alt");
return;
VLOG(3) << "Not adding required assignment because there is one already: "
<< value->ToShortString() << " at " << time << " at "
<< (memory_space == MemorySpace::kDefault ? "def" : "alt");
} else {
VLOG(3) << "Adding required assignment: " << value->ToShortString()
<< " at " << time << " at "
<< (memory_space == MemorySpace::kDefault ? "def" : "alt");
required_assignments_[value].push_back({memory_space, time, chunk});
}
}
required_assignments_[value].push_back(
{aliased_allocation->memory_space(), instruction_time, chunk});
VLOG(3) << "Adding aliased required assignment: " << value->ToShortString()
<< " at " << instruction_time << " at "
<< (aliased_allocation->memory_space() == MemorySpace::kDefault
? "def"
: "alt");
void AlternateMemoryBestFitHeap::AddRequiredAssignment(
const HloInstruction* instruction, ShapeIndex index,
MemorySpace memory_space, absl::optional<Chunk> chunk) {
const HloValue* value =
&alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index);
int64 instruction_time =
hlo_live_range_.instruction_schedule().at(instruction);
AddRequiredAssignment(value, instruction, memory_space, instruction_time,
chunk);
}
void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
@ -1289,6 +1372,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
// First try keeping the allocation entirely in the alternate memory.
if (required_memory_space_at_start != MemorySpace::kDefault &&
required_memory_space_at_end != MemorySpace::kDefault &&
request.allow_no_copy_alternate_mem_allocation &&
AllocateInAlternateMemoryNoCopy(request)) {
return true;
}
@ -1618,9 +1702,14 @@ bool AlternateMemoryBestFitHeap::Prefetch(
// ^ ^
// Copy Copy
// Start Done
options_.prefetch_interval_picker->Begin(
request.use, prev_allocation_in_default_mem.earliest_available_time(),
request.latest_prefetch_time);
int64 earliest_prefetch_time =
prev_allocation_in_default_mem.earliest_available_time();
if (request.earliest_prefetch_time) {
earliest_prefetch_time =
std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
}
options_.prefetch_interval_picker->Begin(request.use, earliest_prefetch_time,
request.latest_prefetch_time);
VLOG(3) << "Trying prefetch picker = "
<< options_.prefetch_interval_picker->ToDebugString();
@ -2435,6 +2524,34 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
events;
auto add_allocation_and_verify = [&](int64 start_time, int64 end_time,
const Chunk& chunk,
const HloValue* value) {
events[std::make_tuple(start_time, /*is_free=*/false, value->id())] =
std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
events[std::make_tuple(end_time, /*is_free=*/true, value->id())] =
std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);
// Get the chunks overlapping in time and search if they overlap in space
// as well.
// TODO(berkin): For now checking against end_time - 1 (exclusive), but we
// really should check against end_time (inclusive) for cases where the
// operand can't share buffer with user (see
// HloDataflowAnalysis::CanShareOperandBufferWithUser).
for (const Chunk& overlapping_chunk :
interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
if (chunk.OverlapsWith(overlapping_chunk)) {
return InternalError(
("Value %s (%d, %d) off: %d size: %d overlaps with another chunk"
" off: %d size: %d"),
value->ToShortString(), start_time, end_time, chunk.offset,
chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
}
}
interval_tree.Add(start_time, end_time - 1, chunk);
return Status::OK();
};
// Go through all instructions in the module to ensure CopyStart/CopyDone
// instructions copy between alternate memory and default memory.
for (const HloComputation* computation :
@ -2470,34 +2587,73 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
for (const HloValue* value : buffer.values()) {
const HloLiveRange::TimeBound& time_bound =
hlo_live_range->buffer_live_ranges().at(value);
events[std::make_tuple(time_bound.start, /*is_free=*/false,
value->id())] =
std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
events[std::make_tuple(time_bound.end, /*is_free=*/true, value->id())] =
std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);
VLOG(3) << " buffer: " << buffer.ToString()
<< " value: " << value->ToShortString() << ": ("
<< time_bound.start << ", " << time_bound.end
<< ") off: " << chunk.offset << ", size: " << chunk.size;
// Get the chunks overlapping in time and search if they overlap in space
// as well.
// TODO(berkin): For now checking against end_time - 1 (exclusive), but we
// really should check against end_time (inclusive) for cases where the
// operand can't share buffer with user (see
// HloDataflowAnalysis::CanShareOperandBufferWithUser).
for (const Chunk& overlapping_chunk :
interval_tree.ChunksOverlappingInTime(time_bound.start,
time_bound.end - 1)) {
if (chunk.OverlapsWith(overlapping_chunk)) {
return InternalError(
("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk"
" off: %d size: %d"),
buffer.ToString(), time_bound.start, time_bound.end, chunk.offset,
chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
const HloInstruction* last_use_instruction = nullptr;
int64 last_use_time = time_bound.start;
for (const HloUse& use : value->uses()) {
int64 use_time =
hlo_live_range->instruction_schedule().at(use.instruction);
if (use_time > last_use_time) {
last_use_time = use_time;
last_use_instruction = use.instruction;
}
}
interval_tree.Add(time_bound.start, time_bound.end - 1, chunk);
if (last_use_instruction &&
last_use_instruction->opcode() == HloOpcode::kConditional) {
// Special case when verifying conditional: we internally split the use
// of alternate memory in conditionals, so fish them out from the
// conditionals.
VLOG(3) << " Splitting conditional buffer: " << buffer.ToString()
<< " value: " << value->ToShortString() << ": ("
<< time_bound.start << ", " << time_bound.end
<< ") off: " << chunk.offset << ", size: " << chunk.size;
int64 earliest_computation_start_time = time_bound.end;
for (const HloComputation* called_computation :
last_use_instruction->called_computations()) {
earliest_computation_start_time =
std::min(earliest_computation_start_time,
hlo_live_range->computation_span_times()
.at(called_computation)
.start);
int64 parameter_time = -1;
int64 last_use_time = -1;
for (const HloPosition& position : value->positions()) {
if (position.instruction->opcode() == HloOpcode::kParameter &&
position.instruction->parent() == called_computation) {
parameter_time = hlo_live_range->instruction_schedule().at(
position.instruction);
break;
}
}
for (const HloUse& use : value->uses()) {
if (use.instruction->parent() == called_computation) {
last_use_time = std::max(
last_use_time,
hlo_live_range->instruction_schedule().at(use.instruction));
}
}
if (last_use_time != -1) {
CHECK_NE(parameter_time, -1);
VLOG(3) << " computation: " << called_computation->name() << ": ("
<< parameter_time << ", " << last_use_time << ")";
TF_RETURN_IF_ERROR(add_allocation_and_verify(
parameter_time, last_use_time, chunk, value));
}
}
VLOG(3) << " from beginning until first computation: ("
<< time_bound.start << ", "
<< (earliest_computation_start_time - 1) << ")";
TF_RETURN_IF_ERROR(add_allocation_and_verify(
time_bound.start, earliest_computation_start_time - 1, chunk,
value));
} else {
VLOG(3) << " buffer: " << buffer.ToString()
<< " value: " << value->ToShortString() << ": ("
<< time_bound.start << ", " << time_bound.end
<< ") off: " << chunk.offset << ", size: " << chunk.size;
TF_RETURN_IF_ERROR(add_allocation_and_verify(
time_bound.start, time_bound.end, chunk, value));
}
}
}

View File

@ -816,11 +816,16 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// use_times is a sorted sequence of the times of all uses.
// latest_prefetch_time is the latest time we can schedule the CopyDone for a
// prefetch.
// If allow_no_copy_alternate_mem_allocation is false, an eviction is forced.
// If earliest_prefetch_time is set, prefetches cannot start before this
// value.
struct AllocationRequest {
int64 start_time;
int64 end_time;
int64 latest_prefetch_time;
int64 size;
bool allow_no_copy_alternate_mem_allocation;
absl::optional<int64> earliest_prefetch_time;
absl::optional<int64> preferred_offset;
HloUse use;
MemorySpaceAssignment::AllocationValue* allocation_value;
@ -841,7 +846,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
// Returns true if the use is allowed in the alternate memory.
bool IsUseAllowedInAlternateMemory(const HloUse& use) const;
bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
const HloUse& use) const;
// Given an HloValue, creates AllocationValue objects and corresponding
// AllocationSequences and appends them into allocation_sequence_list_.
@ -895,6 +901,16 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
const HloInstruction* instruction, ShapeIndex index,
const MemorySpaceAssignment::Allocation* aliased_allocation);
// This sets a required assignment. CHECK fails if there is a conflicting
// required assignment at the same time.
void AddRequiredAssignment(const HloValue* value,
const HloInstruction* instruction,
MemorySpace memory_space, int64 time,
absl::optional<Chunk> chunk = absl::nullopt);
void AddRequiredAssignment(const HloInstruction* instruction,
ShapeIndex index, MemorySpace memory_space,
absl::optional<Chunk> chunk = absl::nullopt);
// Adds input and outputs as required assignments.
void AddInputAndOutputRequiredAssignments();

View File

@ -1663,6 +1663,324 @@ TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
AssignMemorySpace(module.get());
}
TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) {
// Checks if simple conditionals get alternate memory allocations.
absl::string_view hlo_string = R"(
HloModule CondAllocation, is_scheduled=true
true_computation {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg1 = f32[3]{0} negate(gte)
}
false_computation {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg2 = f32[3]{0} negate(gte)
}
ENTRY entry {
p0 = f32[3]{0} parameter(0)
p1 = pred[] parameter(1)
copy = f32[3]{0} copy(p0)
tuple = (f32[3]{0}) tuple(copy)
ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
if (GetParam()) {
// Check that copy and gtes got alternate memory allocations.
auto copy =
module->GetComputationWithName("entry")->GetInstructionWithName("copy");
EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace);
auto neg1 = module->GetComputationWithName("true_computation")
->GetInstructionWithName("neg1");
auto neg1_operand = neg1->operand(0);
EXPECT_EQ(neg1_operand->shape().layout().memory_space(),
kAlternateMemorySpace);
auto neg2 = module->GetComputationWithName("false_computation")
->GetInstructionWithName("neg2");
auto neg2_operand = neg2->operand(0);
EXPECT_EQ(neg2_operand->shape().layout().memory_space(),
kAlternateMemorySpace);
}
}
TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) {
// Checks if we avoid unnecessary allocation in alternate memory if the input
// won't be used in the computation for a long time.
absl::string_view hlo_string = R"(
HloModule CondAllocation, is_scheduled=true
true_computation {
p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
gte0 = f32[3]{0} get-tuple-element(p0), index=0
neg0 = f32[3]{0} negate(gte0)
neg1 = f32[3]{0} negate(neg0)
neg2 = f32[3]{0} negate(neg1)
neg3 = f32[3]{0} negate(neg2)
neg4 = f32[3]{0} negate(neg3)
neg5 = f32[3]{0} negate(neg4)
neg6 = f32[3]{0} negate(neg5)
neg7 = f32[3]{0} negate(neg6)
neg8 = f32[3]{0} negate(neg7)
neg9 = f32[3]{0} negate(neg8)
gte1 = f32[3]{0} get-tuple-element(p0), index=1
ROOT add = f32[3]{0} add(neg9, gte1)
}
false_computation {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg = f32[3]{0} negate(gte)
}
ENTRY entry {
p0 = f32[3]{0} parameter(0)
p1 = pred[] parameter(1)
copy0 = f32[3]{0} copy(p0)
copy1 = f32[3]{0} copy(p0)
tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
tuple1 = (f32[3]{0}) tuple(copy0)
ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
if (GetParam()) {
// Check that copy1 doesn't get unnecessarily allocated in alternate mem
// (due to long negate chain in true_computation) but is prefetched before
// add.
auto copy0 =
module->GetComputationWithName("entry")->GetInstructionWithName(
"copy0");
EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
auto copy1 =
module->GetComputationWithName("entry")->GetInstructionWithName(
"copy1");
EXPECT_EQ(copy1->shape().layout().memory_space(), kDefaultMemorySpace);
auto add = module->GetComputationWithName("true_computation")
->GetInstructionWithName("add");
auto add_operand = add->operand(1);
EXPECT_EQ(add_operand->shape().layout().memory_space(),
kAlternateMemorySpace);
}
}
TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) {
// Make sure there is an evict when there is a conditional use followed by
// another use.
absl::string_view hlo_string = R"(
HloModule CondAllocation, is_scheduled=true
true_computation {
p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
gte0 = f32[3]{0} get-tuple-element(p0), index=0
gte1 = f32[3]{0} get-tuple-element(p0), index=1
add0 = f32[3]{0} add(gte0, gte1)
neg0 = f32[3]{0} negate(add0)
neg1 = f32[3]{0} negate(neg0)
neg2 = f32[3]{0} negate(neg1)
neg3 = f32[3]{0} negate(neg2)
neg4 = f32[3]{0} negate(neg3)
neg5 = f32[3]{0} negate(neg4)
neg6 = f32[3]{0} negate(neg5)
neg7 = f32[3]{0} negate(neg6)
neg8 = f32[3]{0} negate(neg7)
ROOT neg9 = f32[3]{0} negate(neg8)
}
false_computation {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg = f32[3]{0} negate(gte)
}
ENTRY entry {
p0 = f32[3]{0} parameter(0)
p1 = pred[] parameter(1)
copy0 = f32[3]{0} copy(p0)
copy1 = f32[3]{0} copy(p0)
tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
tuple1 = (f32[3]{0}) tuple(copy0)
conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
ROOT add1 = f32[3]{0} add(copy1, conditional)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
if (GetParam()) {
// Make sure the copy1->add edge is in alternate memory. Before conditional,
// this should be evicted to default memory and neg uses the input from
// default memory.
auto copy1 =
module->GetComputationWithName("entry")->GetInstructionWithName(
"copy1");
EXPECT_EQ(copy1->shape().layout().memory_space(), kAlternateMemorySpace);
auto add0 = module->GetComputationWithName("true_computation")
->GetInstructionWithName("add0");
auto add0_operand = add0->operand(1);
EXPECT_EQ(add0_operand->shape().layout().memory_space(),
kAlternateMemorySpace);
auto add1 =
module->GetComputationWithName("entry")->GetInstructionWithName("add1");
auto add1_operand = add1->operand(0);
EXPECT_EQ(add1_operand->shape().layout().memory_space(),
kDefaultMemorySpace);
EXPECT_EQ(add1_operand->opcode(), HloOpcode::kCopyDone);
}
}
TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) {
absl::string_view hlo_string = R"(
HloModule CondAllocation, is_scheduled=true
true_computation {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg1 = f32[3]{0} negate(gte)
}
false_computation {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg2 = f32[3]{0} negate(gte)
}
while_cond {
p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
ROOT gte = pred[] get-tuple-element(p0), index=2
}
while_body {
p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
gte0 = f32[3]{0} get-tuple-element(p0), index=0
gte1 = f32[3]{0} get-tuple-element(p0), index=1
gte2 = pred[] get-tuple-element(p0), index=2
cond_tuple = (f32[3]{0}) tuple(gte0)
conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation
add = f32[3]{0} add(conditional, gte1)
neg0 = f32[3]{0} negate(add)
neg1 = f32[3]{0} negate(neg0)
ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2)
}
ENTRY entry {
p0 = f32[3]{0} parameter(0)
p1 = pred[] parameter(1)
copy0 = f32[3]{0} copy(p0)
copy1 = f32[3]{0} copy(p0)
tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1)
while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
ROOT gte = f32[3]{0} get-tuple-element(while), index=1
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
if (GetParam()) {
// Make sure copy1/while{0}/cond_tuple{0} gets alternate memory allocation.
// This will force an eviction and a prefetch for while body root.
auto copy0 =
module->GetComputationWithName("entry")->GetInstructionWithName(
"copy0");
EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
auto conditional = module->GetComputationWithName("while_body")
->GetInstructionWithName("conditional");
auto conditional_operand = conditional->operand(1);
EXPECT_EQ(ShapeUtil::GetSubshape(conditional_operand->shape(), {0})
.layout()
.memory_space(),
kAlternateMemorySpace);
auto while_root =
module->GetComputationWithName("while_body")->root_instruction();
auto while_root_operand = while_root->operand(0);
EXPECT_THAT(
while_root_operand,
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
op::GetTupleElement(op::Parameter(0)))));
}
}
TEST_P(MemorySpaceAssignmentTest, NestedConditional) {
absl::string_view hlo_string = R"(
HloModule CondAllocation, is_scheduled=true
true_computation2 {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg1 = f32[3]{0} negate(gte)
}
false_computation2 {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg2 = f32[3]{0} negate(gte)
}
true_computation1 {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
slice = f32[1]{0} slice(gte), slice={[0:1]}
bitcast = f32[] bitcast(slice)
constant = f32[] constant(0.0)
compare = pred[] compare(bitcast, constant), direction=GT
ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2
}
false_computation1 {
p0 = (f32[3]{0}) parameter(0)
gte = f32[3]{0} get-tuple-element(p0), index=0
ROOT neg3 = f32[3]{0} negate(gte)
}
ENTRY entry {
p0 = f32[3]{0} parameter(0)
p1 = pred[] parameter(1)
copy = f32[3]{0} copy(p0)
tuple = (f32[3]{0}) tuple(copy)
ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
if (GetParam()) {
// Make sure alternate memory allocation gets propagated into both levels of
// conditional.
auto copy =
module->GetComputationWithName("entry")->GetInstructionWithName("copy");
EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace);
auto neg1_operand = module->GetComputationWithName("true_computation2")
->GetInstructionWithName("neg1")
->operand(0);
auto neg2_operand = module->GetComputationWithName("false_computation2")
->GetInstructionWithName("neg2")
->operand(0);
auto neg3_operand = module->GetComputationWithName("false_computation1")
->GetInstructionWithName("neg3")
->operand(0);
EXPECT_EQ(neg1_operand->shape().layout().memory_space(),
kAlternateMemorySpace);
EXPECT_EQ(neg2_operand->shape().layout().memory_space(),
kAlternateMemorySpace);
EXPECT_EQ(neg3_operand->shape().layout().memory_space(),
kAlternateMemorySpace);
}
}
TEST_P(MemorySpaceAssignmentTest,
RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
// Ensure that request identifier returned by Send/Recv HLOs are not allocated
@ -2149,7 +2467,8 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
AssignMemorySpace(module.get(), -1, 5);
}
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
// TODO(berkin): This might be an incorrect input graph, investigate.
TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) {
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});