[XLA] Implement memory space allocation across sequential calls (e.g. while).

For now, the heuristics aren't very good; they also need to be extended to
allow moving the buffer between memory spaces inside the while body.

PiperOrigin-RevId: 281575698
Change-Id: I7c50a4ea4001021e0de44ff7643cc7a7cd44d7bd
Berkin Ilbeyi 2019-11-20 12:25:27 -08:00 committed by TensorFlower Gardener
parent 9401effbed
commit 162048befe
4 changed files with 312 additions and 108 deletions
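The new behavior is opt-in via an option on memory space assignment. Below is a minimal sketch of how a caller might enable it, modeled on the test helper changed in this commit; the wrapper function name and the include lines are assumptions, not part of the commit.

// Sketch only (hypothetical wrapper): the option and Run() signature come from
// this commit's diff; the header path and function name are assumptions.
#include <memory>

#include "tensorflow/compiler/xla/service/memory_space_assignment.h"

namespace xla {

StatusOr<std::unique_ptr<PresetAssignments>> AssignAcrossSequentialCalls(
    HloModule* module, MemorySpaceAssignment::Options options) {
  // Allow buffers to stay in the alternate memory across sequential calls
  // (kWhile, kCall, and kConditional), e.g. before and inside a while body.
  options.allocate_across_sequential_calls = true;
  // -1 leaves the number of outstanding asynchronous copies unlimited.
  options.max_outstanding_async_copies = -1;
  return MemorySpaceAssignment::Run(module, options);
}

}  // namespace xla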


@ -83,6 +83,12 @@ class HloLiveRange {
return buffer_live_ranges_;
}
// Returns the map from a computation to its time span in the schedule.
const absl::flat_hash_map<const HloComputation*, TimeBound>&
computation_span_times() const {
return computation_span_times_;
}
// Returns the time stamp of the end of the program.
LogicalTime schedule_end_time() const { return schedule_end_time_; }

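To illustrate how a consumer of computation_span_times() caps prefetch times for sequential calls (as memory_space_assignment.cc does below), here is a small self-contained mock; the types, map contents, and time values are invented for illustration and are not XLA code.

// Standalone mock, not XLA code: only the std::min() clamping mirrors the
// logic this commit adds to AlternateMemoryBestFitHeap::Finish().
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct TimeBound {
  int64_t start;
  int64_t end;
};

int main() {
  // Hypothetical flattened schedule: the while condition spans [10, 12), the
  // while body spans [12, 40), and the while instruction (the use) is at 41.
  const std::map<std::string, TimeBound> computation_span_times = {
      {"while_cond", {10, 12}}, {"while_body", {12, 40}}};
  const int64_t use_time = 41;

  int64_t latest_prefetch_time = use_time;
  for (const auto& entry : computation_span_times) {
    // The prefetch (CopyDone) must complete before the earliest called
    // computation starts, so clamp to the smallest span start.
    latest_prefetch_time = std::min(entry.second.start, latest_prefetch_time);
  }
  std::cout << "latest prefetch time = " << latest_prefetch_time << "\n";  // 10
  return 0;
}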

@ -237,24 +237,19 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
}
auto colocated_intervals = GetSortedColocatedIntervals(interval);
bool keep_in_default_memory = false;
for (const BufferInterval* colocated_interval : colocated_intervals) {
const HloValue* value = colocated_interval->buffer;
// If any of the colocated values are phi buffers, we keep them in the
// default memory for now.
if (value->is_phi()) {
keep_in_default_memory = true;
VLOG(4) << "Keeping value " << value->ToShortString()
<< " because it contains a phi node.";
break;
}
if (colocated_intervals.size() > 1 &&
!options_.allocate_across_sequential_calls) {
VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
<< " because it aliases with another interval and "
<< " allocate_across_sequential_calls is false.";
continue;
}
// At this point, none of the colocated buffers contain any phi buffers.
const HloComputation* defining_computation =
colocated_intervals[0]->buffer->defining_instruction()->parent();
MemorySpaceAssignment::Allocation* aliased_allocation = nullptr;
for (const BufferInterval* colocated_interval : colocated_intervals) {
if (keep_in_default_memory) {
break;
}
const HloValue* value = colocated_interval->buffer;
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
MemorySpaceAssignment::AllocationSequence* allocation_sequence =
@ -267,25 +262,66 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
return instruction_schedule.at(use1.instruction) <
instruction_schedule.at(use2.instruction);
});
// If there was an aliased allocation for this buffer, propagate that for
// this HloValue.
if (aliased_allocation != nullptr) {
VLOG(3) << "Adding an aliased allocation: ("
<< aliased_allocation->start_time() << ", "
<< aliased_allocation->end_time()
<< ") pos: " << aliased_allocation->defining_position()
<< " mem space: "
<< (aliased_allocation->memory_space() == MemorySpace::kDefault
? "default"
: "alt");
allocation_sequence->push_back(
absl::make_unique<MemorySpaceAssignment::Allocation>(
value->defining_instruction(), value->defining_position(),
aliased_allocation->memory_space(), aliased_allocation->chunk(),
aliased_allocation->start_time(),
aliased_allocation->end_time()));
}
// Iterate over the uses.
for (HloUse use : uses) {
int64 use_time = instruction_schedule.at(use.instruction);
int64 last_use_time = instruction_schedule.at(uses.back().instruction);
int64 latest_prefetch_time = use_time;
if (use.instruction->parent() != defining_computation) {
VLOG(3) << "skip use " << use.ToString()
<< " because it's in a different computation.";
continue;
}
// Sequential calls include kWhile, kCall, and kConditional opcodes.
bool is_sequential_call =
(GetInstructionCallContext(use.instruction->opcode()) ==
CallContext::kSequential);
if (is_sequential_call) {
for (const HloComputation* called_computation :
use.instruction->called_computations()) {
const HloLiveRange::TimeBound& computation_span =
hlo_live_range_.computation_span_times().at(called_computation);
latest_prefetch_time =
std::min(computation_span.start, latest_prefetch_time);
}
}
// Bitcasts don't define buffers and don't directly consume buffers.
// Skip allocating buffers for bitcast uses. The uses that feed from
// bitcasts will be handled specially.
if (use.instruction->opcode() != HloOpcode::kBitcast) {
if (!FindAllocation(definition_time, use_time, last_use_time,
value->defining_position(), use, value,
colocated_interval->size, allocation_sequence)) {
latest_prefetch_time, value->defining_position(),
use, value, colocated_interval->size,
allocation_sequence)) {
// If the allocation finding failed (e.g., due to running out of
// asynchronous copies), then fall back to allocating the buffer
// entirely in the default memory.
pending_chunks_.clear();
pending_async_copies_.clear();
allocation_sequence->clear();
keep_in_default_memory = true;
break;
}
@ -293,6 +329,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
// allocation already at the alternate memory.
definition_time = use_time;
}
// If the use was a sequential call (e.g. a while loop), the other
// colocated intervals must alias with this allocation.
if (is_sequential_call && !allocation_sequence->empty()) {
aliased_allocation = allocation_sequence->back().get();
}
}
}
@ -390,8 +432,9 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks(
bool AlternateMemoryBestFitHeap::FindAllocation(
int64 start_time, int64 end_time, int64 last_use_time,
HloPosition defining_position, HloUse use, const HloValue* buffer,
int64 size, MemorySpaceAssignment::AllocationSequence* allocations) {
int64 latest_prefetch_time, HloPosition defining_position, HloUse use,
const HloValue* buffer, int64 size,
MemorySpaceAssignment::AllocationSequence* allocations) {
HloInstruction* operand =
use.instruction->mutable_operand(use.operand_number);
// If the operand is a bitcast, we look at bitcast's operand until we find a
@ -408,8 +451,10 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
alternate_mem_interval.end = end_time;
VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " ("
<< start_time << ", " << end_time << ") last use = " << last_use_time
<< " use = " << use.ToString() << ". Size = " << size
<< start_time << ", " << end_time
<< ") latest prefetch = " << latest_prefetch_time
<< " last use = " << last_use_time << " use = " << use.ToString()
<< ". Size = " << size
<< ", def pos = " << defining_position.ToString()
<< ", operand = " << operand->ToShortString()
<< (non_bitcast_operand != operand
@ -445,19 +490,6 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
}
}
// TODO(berkin): This is currently overly restrictive and will fail using
// alternate memory for any buffer that might leak into a different
// computation (e.g., while body). Enable more usage of alternate memory
// across computations.
if (defining_position.instruction->parent() != use.instruction->parent() ||
(!use.instruction->called_computations().empty() &&
use.instruction->opcode() != HloOpcode::kFusion)) {
VLOG(3) << "Use is in a different computation or calls a computation.";
// Fail because we do not allow asynchronous copies while in the bodies of
// other computations.
return false;
}
// First try keeping the allocation entirely in the alternate memory.
if (!definition_requires_buffer_in_default_mem &&
!use_requires_buffer_in_default_mem &&
@ -491,7 +523,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
prev_allocation->end_time())) {
AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
prev_allocation->start_time(), prev_allocation->end_time(),
allocations);
prev_allocation->end_time(), allocations);
} else {
VLOG(3) << "This violates the maximum async copies.";
@ -504,7 +536,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
if (!ViolatesMaximumOutstandingAsyncCopies(time, time)) {
VLOG(3) << "Eviction successful.";
AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
time, time, allocations);
time, time, time, allocations);
eviction_scheduled = true;
break;
}
@ -558,7 +590,8 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
// ^ ^
// Copy Copy
// Start Done
options_.prefetch_interval_picker->Begin(use, start_time, end_time);
options_.prefetch_interval_picker->Begin(use, start_time,
latest_prefetch_time);
while (!options_.prefetch_interval_picker->Done()) {
alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
VLOG(4) << "Trying alternate memory allocation ("
@ -583,7 +616,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
AddAsyncCopy(*allocations->back().get(), MemorySpace::kAlternate,
chunk_candidate.chunk, alternate_mem_interval.start,
end_time, allocations);
end_time, latest_prefetch_time, allocations);
allocations->back()->AddUse(use);
return true;
@ -598,16 +631,19 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
void AlternateMemoryBestFitHeap::AddAsyncCopy(
const MemorySpaceAssignment::Allocation& prev_allocation,
MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time,
int64 copy_done_schedule_before_time,
MemorySpaceAssignment::AllocationSequence* allocations) {
VLOG(3) << "Copy to "
<< (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
? "default"
: "alternate")
<< " memory between " << start_time << " and " << end_time;
<< " memory between " << start_time << " and "
<< copy_done_schedule_before_time << " keeping until " << end_time;
allocations->push_back(
absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
prev_allocation, memory_space, chunk, start_time, end_time));
prev_allocation, memory_space, chunk, start_time, end_time,
copy_done_schedule_before_time));
// Register the additional async copy with the interval tree to keep track of
// the limit at any given time.
@ -828,9 +864,12 @@ MemorySpaceAssignment::Run(HloModule* module, const Options& options) {
&memory_space_assignment.allocation_map_, options, *alias_analysis,
*hlo_live_range);
HeapSimulator::Options heap_simulator_options;
heap_simulator_options.may_reuse_operand_buffers = false;
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
module->schedule(),
*alias_analysis.get(), options.size_fn)
*alias_analysis.get(), options.size_fn,
heap_simulator_options)
.status());
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
@ -1221,28 +1260,30 @@ Status MemorySpaceAssignment::FixSchedule() {
instruction_index <
flattened_instruction_sequence_.instructions().size();
++instruction_index) {
HloInstruction* instruction =
flattened_instruction_sequence_.instructions()[instruction_index];
if (instruction->parent() != computation) {
continue;
}
auto insts_before_iter = schedule_before_.find(instruction_index);
if (insts_before_iter != schedule_before_.end()) {
for (HloInstruction* new_instruction : insts_before_iter->second) {
EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
&inserted_instructions);
if (new_instruction->parent() == computation) {
EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
&inserted_instructions);
}
}
}
HloInstruction* instruction =
flattened_instruction_sequence_.instructions()[instruction_index];
// Insert only if not previously inserted.
if (!inserted_instructions.contains(instruction)) {
if (!inserted_instructions.contains(instruction) &&
instruction->parent() == computation) {
EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
&inserted_instructions);
}
auto insts_after_iter = schedule_after_.find(instruction_index);
if (insts_after_iter != schedule_after_.end()) {
for (HloInstruction* new_instruction : insts_after_iter->second) {
EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
&inserted_instructions);
if (new_instruction->parent() == computation) {
EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
&inserted_instructions);
}
}
}
}


@ -268,6 +268,10 @@ class MemorySpaceAssignment {
// Specifies the upper bound for number of outstanding asynchronous copies,
// -1 for unlimited.
int64 max_outstanding_async_copies = -1;
// If true, tries to allocate buffers across sequential calls (kWhile, kCall,
// and kConditional), e.g., before and inside a while loop body.
bool allocate_across_sequential_calls = false;
};
// This class represents an allocation that might either be in the default or
@ -363,13 +367,14 @@ class MemorySpaceAssignment {
class CopyAllocation : public Allocation {
public:
CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
Chunk chunk, int64 start_time, int64 end_time)
Chunk chunk, int64 start_time, int64 end_time,
int64 copy_done_schedule_before_time)
: Allocation(/*instruction=*/nullptr,
/*defining_position=*/{nullptr, {}}, memory_space, chunk,
start_time, end_time),
prev_allocation_(prev_allocation),
copy_start_schedule_after_(start_time),
copy_done_schedule_before_(end_time) {}
copy_done_schedule_before_(copy_done_schedule_before_time) {}
bool is_copy_allocation() const override { return true; }
@ -525,8 +530,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// allocations can be in default or alternate memory spaces, or can be
// prefetches or evictions. Returns true if successful.
bool FindAllocation(int64 start_time, int64 end_time, int64 last_use_time,
HloPosition defining_position, HloUse use,
const HloValue* buffer, int64 size,
int64 latest_prefetch_time, HloPosition defining_position,
HloUse use, const HloValue* buffer, int64 size,
MemorySpaceAssignment::AllocationSequence* allocations);
// Try allocating in alternate memory without any copies. Returns true if
@ -560,7 +565,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// Adds an asynchronous copy to the allocations.
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
MemorySpace memory_space, Chunk chunk, int64 start_time,
int64 end_time,
int64 end_time, int64 copy_done_schedule_before_time,
MemorySpaceAssignment::AllocationSequence* allocations);
// These methods are used for delaying committing the chunk candidate until

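The header change above separates when the asynchronous copy must finish (copy_done_schedule_before_time) from how long the buffer stays allocated (end_time). Below is a hedged fragment, with invented times, of how AddAsyncCopy in the .cc diff constructs such an allocation.

// Fragment only; mirrors AddAsyncCopy above. The concrete times and chunk are
// invented: CopyStart is scheduled after time 5, CopyDone must land before
// time 10 (where the while body begins), and the alternate-memory buffer
// stays live until the last use at time 41.
allocations->push_back(
    absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
        prev_allocation, MemorySpace::kAlternate, chunk,
        /*start_time=*/5, /*end_time=*/41,
        /*copy_done_schedule_before_time=*/10));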

@ -35,7 +35,8 @@ int64 ShapeSize(const Shape& shape) {
return ShapeUtil::ByteSizeOf(shape, kPointerSize);
}
class MemorySpaceAssignmentTest : public HloTestBase {
class MemorySpaceAssignmentTest : public HloTestBase,
public ::testing::WithParamInterface<bool> {
protected:
// We use the following two memory space values to describe the default (slow
// and large) and alternate (fast and small) memory spaces.
@ -105,6 +106,7 @@ class MemorySpaceAssignmentTest : public HloTestBase {
options.size_fn = size_fn;
options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
options.max_outstanding_async_copies = max_outstanding_async_copies;
options.allocate_across_sequential_calls = GetParam();
std::unique_ptr<PresetAssignments> preset_assignments =
MemorySpaceAssignment::Run(module, options).ValueOrDie();
CheckPresetAssignments(preset_assignments.get());
@ -190,7 +192,7 @@ class MemorySpaceAssignmentTest : public HloTestBase {
}
};
TEST_F(MemorySpaceAssignmentTest, ParameterOnly) {
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
// A module consisting of a single parameter. Inputs/outputs are currently
// excluded from memory space assignment.
HloComputation::Builder builder(TestName());
@ -210,7 +212,7 @@ TEST_F(MemorySpaceAssignmentTest, ParameterOnly) {
EXPECT_THAT(p0, op::ShapeWithLayout(shape));
}
TEST_F(MemorySpaceAssignmentTest, Simple) {
TEST_P(MemorySpaceAssignmentTest, Simple) {
// A simple module with a few simple instructions. Expect this to be
// transformed with CopyStart and CopyDone instructions inserted after inputs
// and before outputs.
@ -256,7 +258,7 @@ TEST_F(MemorySpaceAssignmentTest, Simple) {
preset_assignments->chunks()[1].second.offset);
}
TEST_F(MemorySpaceAssignmentTest, NegateChain) {
TEST_P(MemorySpaceAssignmentTest, NegateChain) {
// The negate chain is long enough for an asynchronous copy to be inserted
// between p1 and add.
HloComputation::Builder builder(TestName());
@ -319,7 +321,7 @@ TEST_F(MemorySpaceAssignmentTest, NegateChain) {
EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
}
TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) {
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetch) {
std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
AssignMemorySpace(module.get());
@ -330,12 +332,9 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) {
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
op::AsyncCopy(kDefaultMemorySpace,
kAlternateMemorySpace, op::Tanh()))));
EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module),
2);
}
TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0);
@ -344,7 +343,7 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
0);
}
TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);
@ -353,7 +352,16 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
1);
}
TEST_F(MemorySpaceAssignmentTest, While) {
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) {
std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2);
EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module),
2);
}
TEST_P(MemorySpaceAssignmentTest, While) {
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
@ -429,14 +437,18 @@ TEST_F(MemorySpaceAssignmentTest, While) {
AssignMemorySpace(module.get());
// Ensure the tuple value and buffers used in the while instruction are
// exempted from using the alternate memory. However, body_data_mul is
// independent and can safely be placed in the alternate memory.
EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
EXPECT_THAT(data, op::ShapeWithLayout(shape));
EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
// exempted from using the alternate memory when allocating across sequential
// calls is disabled. However, body_data_mul is independent and can safely
// be placed in the alternate memory.
const bool allocate_across_sequential_calls = GetParam();
if (!allocate_across_sequential_calls) {
EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
EXPECT_THAT(data, op::ShapeWithLayout(shape));
EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
}
Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
F32, {2, 3},
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
@ -444,7 +456,7 @@ TEST_F(MemorySpaceAssignmentTest, While) {
EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
}
TEST_F(MemorySpaceAssignmentTest, Tuple) {
TEST_P(MemorySpaceAssignmentTest, Tuple) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
@ -499,7 +511,7 @@ TEST_F(MemorySpaceAssignmentTest, Tuple) {
op::GetTupleElement(op::GetTupleElement()))));
}
TEST_F(MemorySpaceAssignmentTest, Bitcast) {
TEST_P(MemorySpaceAssignmentTest, Bitcast) {
// Bitcasts can cause the position in the alternate memory to appear multiple
// times in the preset assignments. This test ensures the preset assignments
// refer to unique positions.
@ -528,7 +540,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast) {
EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
}
TEST_F(MemorySpaceAssignmentTest, Bitcast2) {
TEST_P(MemorySpaceAssignmentTest, Bitcast2) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape param_shape = ShapeUtil::MakeShape(F32, {6});
@ -564,7 +576,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast2) {
EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
}
TEST_F(MemorySpaceAssignmentTest, Bitcast3) {
TEST_P(MemorySpaceAssignmentTest, Bitcast3) {
HloComputation::Builder builder(TestName());
Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
@ -627,7 +639,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast3) {
EXPECT_EQ(bitcast4->shape().layout().memory_space(), kAlternateMemorySpace);
}
TEST_F(MemorySpaceAssignmentTest, BitcastTuple) {
TEST_P(MemorySpaceAssignmentTest, BitcastTuple) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape param_shape = ShapeUtil::MakeShape(F32, {6});
@ -678,7 +690,7 @@ TEST_F(MemorySpaceAssignmentTest, BitcastTuple) {
AssignMemorySpace(module.get());
}
TEST_F(MemorySpaceAssignmentTest, LastUseOpt) {
TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
// Test that checks the last use optimization. It uses two buffers that should
// be placed in alternate memory.
//
@ -735,7 +747,7 @@ TEST_F(MemorySpaceAssignmentTest, LastUseOpt) {
op::Add(op::Parameter(0), op::Parameter(0)))));
}
TEST_F(MemorySpaceAssignmentTest, CopyOrdering) {
TEST_P(MemorySpaceAssignmentTest, CopyOrdering) {
// Test to make sure the CopyStarts follow the same CopyDone order. The shapes
// are picked in increasing order to exploit the fact that heap simulator
// processes larger tensors first. This checks the ability of the compiler to
@ -850,7 +862,7 @@ TEST_F(MemorySpaceAssignmentTest, CopyOrdering) {
}
}
TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
// Test to ensure CopyStart/CopyDone is placed only in the entry computation.
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
@ -934,7 +946,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
AssignMemorySpace(module.get(), -1, 50);
}
TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@ -1005,7 +1017,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
AssignMemorySpace(module.get(), -1, 5);
}
TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@ -1071,7 +1083,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
AssignMemorySpace(module.get(), -1, 5);
}
TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@ -1144,7 +1156,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
AssignMemorySpace(module.get(), -1, 5);
}
TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
// This test reproduces the failure in b/143288178. Given a graph like the
// following:
//
@ -1242,7 +1254,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
tuple_shape, cond_computation, body_computation, tuple));
HloInstruction* while_data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, while_op, 0));
HloInstruction::CreateGetTupleElement(scalar_shape, while_op, 1));
HloInstruction* root =
builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
HloComputation* entry_computation =
@ -1265,7 +1277,143 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
AssignMemorySpace(module.get(), -1, 20);
}
TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape, shape});
auto cond_builder = HloComputation::Builder("WhileCond");
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
HloInstruction* cond_iter = cond_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
HloInstruction* cond_limit = cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
HloInstruction* cond_lt = cond_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
cond_limit, ComparisonDirection::kLt));
HloComputation* cond_computation =
module->AddEmbeddedComputation(cond_builder.Build());
auto body_builder = HloComputation::Builder("WhileBody");
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
HloInstruction* body_iter = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
HloInstruction* body_data = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
HloInstruction* body_negate0 = body_builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_data));
HloInstruction* body_negate1 = body_builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate0));
HloInstruction* body_negate2 = body_builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate1));
HloInstruction* body_negate3 = body_builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate2));
HloInstruction* body_negate4 = body_builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate3));
HloInstruction* body_negate5 = body_builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate4));
HloInstruction* body_negate6 = body_builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate5));
HloInstruction* body_negate7 = body_builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate6));
HloInstruction* body_iter_increment = body_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
HloInstruction* body_iter_next =
body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
HloInstruction* body_out = body_builder.AddInstruction(
HloInstruction::CreateTuple({body_data, body_iter_next, body_negate7}));
HloComputation* body_computation =
module->AddEmbeddedComputation(body_builder.Build());
auto builder = HloComputation::Builder(TestName());
HloInstruction* data = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param_data"));
HloInstruction* iter = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
HloInstruction* negate0 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
HloInstruction* negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
HloInstruction* negate2 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
HloInstruction* negate3 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
HloInstruction* negate4 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
HloInstruction* negate5 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
HloInstruction* negate6 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
HloInstruction* negate7 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
HloInstruction* tuple = builder.AddInstruction(
HloInstruction::CreateTuple({data, iter, negate7}));
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
tuple_shape, cond_computation, body_computation, tuple));
HloInstruction* while_data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, while_op, 0));
HloInstruction* while_data2 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, while_op, 2));
HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, while_data, while_data2));
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(cond_computation,
{cond_param, cond_iter, cond_limit, cond_lt});
schedule.set_sequence(
body_computation,
{body_param, body_iter, body_data, body_negate0, body_negate1,
body_negate2, body_negate3, body_negate4, body_negate5, body_negate6,
body_negate7, body_iter_increment, body_iter_next, body_out});
schedule.set_sequence(
entry_computation,
{iter, data, negate0, negate1, negate2, negate3, negate4, negate5,
negate6, negate7, tuple, while_op, while_data, while_data2, root});
TF_CHECK_OK(module->set_schedule(schedule));
// Pick a large max prefetch interval to ensure all the while inputs are
// allocated in the alternate memory.
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
/*max_prefetch_interval=*/25);
int64 memory_space_across_while = kDefaultMemorySpace;
bool allocate_across_sequential_calls = GetParam();
if (allocate_across_sequential_calls) {
memory_space_across_while = kAlternateMemorySpace;
}
// Index {0} of the while loop argument is not written inside the while loop,
// so it can be trivially placed in the alternate memory space.
*ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
LayoutUtil::MakeLayout(
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
kAlternateMemorySpace);
// Indexes {1} and {2} of the while loop argument are only placed in the
// alternate memory if we enable the allocate_across_sequential_calls option.
*ShapeUtil::GetMutableSubshape(&tuple_shape, {1})->mutable_layout() =
LayoutUtil::MakeLayout(
/*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
memory_space_across_while);
*ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
LayoutUtil::MakeLayout(
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
memory_space_across_while);
// Expect the layout for the while loop and its aliased buffers.
EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
EXPECT_THAT(while_op->operand(0), op::ShapeWithLayout(tuple_shape));
EXPECT_THAT(cond_param, op::ShapeWithLayout(tuple_shape));
EXPECT_THAT(body_param, op::ShapeWithLayout(tuple_shape));
EXPECT_THAT(body_out, op::ShapeWithLayout(tuple_shape));
}
TEST_P(MemorySpaceAssignmentTest, DanglingCopy) {
// This situation was encountered in vss, where there is a mismatch in the
// memory space between the preset assignments and the output graph.
HloComputation::Builder builder(TestName());
@ -1311,7 +1459,7 @@ TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
AssignMemorySpace(module.get());
}
TEST_F(MemorySpaceAssignmentTest, MultiOutputFusion) {
TEST_P(MemorySpaceAssignmentTest, MultiOutputFusion) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@ -1348,7 +1496,7 @@ TEST_F(MemorySpaceAssignmentTest, MultiOutputFusion) {
AssignMemorySpace(module.get());
}
TEST_F(MemorySpaceAssignmentTest, TupleInput) {
TEST_P(MemorySpaceAssignmentTest, TupleInput) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@ -1388,7 +1536,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleInput) {
AssignMemorySpace(module.get());
}
TEST_F(MemorySpaceAssignmentTest, TupleToTuple1) {
TEST_P(MemorySpaceAssignmentTest, TupleToTuple1) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@ -1467,7 +1615,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple1) {
op::GetTupleElement(op::Fusion(), 1)))));
}
TEST_F(MemorySpaceAssignmentTest, TupleToTuple2) {
TEST_P(MemorySpaceAssignmentTest, TupleToTuple2) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@ -1547,7 +1695,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple2) {
op::GetTupleElement(op::Fusion(), 1), 1))))));
}
TEST_F(MemorySpaceAssignmentTest, TupleToTuple3) {
TEST_P(MemorySpaceAssignmentTest, TupleToTuple3) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@ -1594,7 +1742,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple3) {
EXPECT_THAT(fusion1, op::Fusion(op::Fusion()));
}
TEST_F(MemorySpaceAssignmentTest, InputOutputAlias) {
TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@ -1649,7 +1797,7 @@ TEST_F(MemorySpaceAssignmentTest, InputOutputAlias) {
kDefaultMemorySpace);
}
TEST_F(MemorySpaceAssignmentTest, CostAnalysis) {
TEST_P(MemorySpaceAssignmentTest, CostAnalysis) {
// This is mostly a smoke test since it's difficult and brittle to work out
// the cost of the HLO instructions.
HloComputation::Builder builder(TestName());
@ -1701,7 +1849,7 @@ TEST_F(MemorySpaceAssignmentTest, CostAnalysis) {
EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
}
TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
// This test is carefully crafted to force only negates to be allocated to the
// alternate memory. The graph consists of interleaving negate and tanh
// operations:
@ -1762,16 +1910,16 @@ TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
F32, {4, 6},
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
kDefaultMemorySpace);
Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
F32, {4, 6},
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
kAlternateMemorySpace);
// Expect only negates to be in alternate memory space.
EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
// Expect only negates to be in alternate memory space. Not all might fit but
// make sure at least one does.
std::vector<HloInstruction*> negate_instructions = {negate0, negate1, negate2,
negate3, negate4};
int64 num_negates_in_alternate_mem = absl::c_count_if(
negate_instructions, [&](const HloInstruction* instruction) {
return instruction->shape().layout().memory_space() ==
kAlternateMemorySpace;
});
EXPECT_GE(num_negates_in_alternate_mem, 1);
EXPECT_THAT(tanh0, op::ShapeWithLayout(shape_in_default_mem));
EXPECT_THAT(tanh1, op::ShapeWithLayout(shape_in_default_mem));
EXPECT_THAT(tanh2, op::ShapeWithLayout(shape_in_default_mem));
@ -1779,5 +1927,9 @@ TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
}
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
MemorySpaceAssignmentTest,
::testing::Values(false, true));
} // namespace
} // namespace xla