diff --git a/tensorflow/compiler/xla/service/hlo_live_range.h b/tensorflow/compiler/xla/service/hlo_live_range.h
index cc0445acd1e..5464de696de 100644
--- a/tensorflow/compiler/xla/service/hlo_live_range.h
+++ b/tensorflow/compiler/xla/service/hlo_live_range.h
@@ -83,6 +83,12 @@ class HloLiveRange {
     return buffer_live_ranges_;
   }
 
+  // Returns the map from a computation and its time span in the schedule.
+  const absl::flat_hash_map<const HloComputation*, TimeBound>&
+  computation_span_times() const {
+    return computation_span_times_;
+  }
+
   // Returns the time stamp of the end of the program.
   LogicalTime schedule_end_time() const { return schedule_end_time_; }
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index a9d4c1e7029..5e8b75ee50e 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -237,24 +237,19 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
     }
 
     auto colocated_intervals = GetSortedColocatedIntervals(interval);
-    bool keep_in_default_memory = false;
-    for (const BufferInterval* colocated_interval : colocated_intervals) {
-      const HloValue* value = colocated_interval->buffer;
-      // If any of the colocated values are phi buffers, we keep them in the
-      // default memory for now.
-      if (value->is_phi()) {
-        keep_in_default_memory = true;
-        VLOG(4) << "Keeping value " << value->ToShortString()
-                << " because it contains a phi node.";
-        break;
-      }
+
+    if (colocated_intervals.size() > 1 &&
+        !options_.allocate_across_sequential_calls) {
+      VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
+              << " because it aliases with another interval and "
+              << " allocate_across_sequential_calls is false.";
+      continue;
     }
-    // At this point, none of the colocated buffers contain any phi buffers.
+
+    const HloComputation* defining_computation =
+        colocated_intervals[0]->buffer->defining_instruction()->parent();
+    MemorySpaceAssignment::Allocation* aliased_allocation = nullptr;
     for (const BufferInterval* colocated_interval : colocated_intervals) {
-      if (keep_in_default_memory) {
-        break;
-      }
       const HloValue* value = colocated_interval->buffer;
       const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
       MemorySpaceAssignment::AllocationSequence* allocation_sequence =
@@ -267,25 +262,66 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
                   return instruction_schedule.at(use1.instruction) <
                          instruction_schedule.at(use2.instruction);
                 });
+
+      // If there was an aliased allocation for this buffer, propagate that for
+      // this HloValue.
+      if (aliased_allocation != nullptr) {
+        VLOG(3) << "Adding an aliased allocation: ("
+                << aliased_allocation->start_time() << ", "
+                << aliased_allocation->end_time()
+                << ") pos: " << aliased_allocation->defining_position()
+                << " mem space: "
+                << (aliased_allocation->memory_space() == MemorySpace::kDefault
+                        ? "default"
+                        : "alt");
+        allocation_sequence->push_back(
+            absl::make_unique<MemorySpaceAssignment::Allocation>(
+                value->defining_instruction(), value->defining_position(),
+                aliased_allocation->memory_space(), aliased_allocation->chunk(),
+                aliased_allocation->start_time(),
+                aliased_allocation->end_time()));
+      }
+
+      // Iterate over the uses.
       for (HloUse use : uses) {
         int64 use_time = instruction_schedule.at(use.instruction);
         int64 last_use_time = instruction_schedule.at(uses.back().instruction);
+        int64 latest_prefetch_time = use_time;
+
+        if (use.instruction->parent() != defining_computation) {
+          VLOG(3) << "skip use " << use.ToString()
+                  << " because it's in a different computation.";
+          continue;
+        }
+
+        // Sequential calls include kWhile, kCall, and kConditional opcodes.
+        bool is_sequential_call =
+            (GetInstructionCallContext(use.instruction->opcode()) ==
+             CallContext::kSequential);
+        if (is_sequential_call) {
+          for (const HloComputation* called_computation :
+               use.instruction->called_computations()) {
+            const HloLiveRange::TimeBound& computation_span =
+                hlo_live_range_.computation_span_times().at(called_computation);
+            latest_prefetch_time =
+                std::min(computation_span.start, latest_prefetch_time);
+          }
+        }
 
         // Bitcasts don't define buffers and don't directly consume buffers.
         // Skip allocating buffers for bitcast uses. The uses that feed from
         // bitcasts will be handled specially.
         if (use.instruction->opcode() != HloOpcode::kBitcast) {
           if (!FindAllocation(definition_time, use_time, last_use_time,
-                              value->defining_position(), use, value,
-                              colocated_interval->size, allocation_sequence)) {
+                              latest_prefetch_time, value->defining_position(),
+                              use, value, colocated_interval->size,
+                              allocation_sequence)) {
             // If the allocation finding failed (e.g., due to running out of
             // asynchronous copies), then fall back to allocating the buffer
             // entirely in the default memory.
             pending_chunks_.clear();
             pending_async_copies_.clear();
             allocation_sequence->clear();
-            keep_in_default_memory = true;
             break;
           }
 
@@ -293,6 +329,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
           // allocation already at the alternate memory.
           definition_time = use_time;
         }
+
+        // If the use has been a sequential call (e.g. a while loop), the other
+        // colocated intervals must alias with this allocation.
+        if (is_sequential_call && !allocation_sequence->empty()) {
+          aliased_allocation = allocation_sequence->back().get();
+        }
       }
     }
 
@@ -390,8 +432,9 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks(
 
 bool AlternateMemoryBestFitHeap::FindAllocation(
     int64 start_time, int64 end_time, int64 last_use_time,
-    HloPosition defining_position, HloUse use, const HloValue* buffer,
-    int64 size, MemorySpaceAssignment::AllocationSequence* allocations) {
+    int64 latest_prefetch_time, HloPosition defining_position, HloUse use,
+    const HloValue* buffer, int64 size,
+    MemorySpaceAssignment::AllocationSequence* allocations) {
   HloInstruction* operand =
       use.instruction->mutable_operand(use.operand_number);
   // If the operand is a bitcast, we look at bitcast's operand until we find a
@@ -408,8 +451,10 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
   alternate_mem_interval.end = end_time;
 
   VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " ("
-          << start_time << ", " << end_time << ") last use = " << last_use_time
-          << " use = " << use.ToString() << ". Size = " << size
+          << start_time << ", " << end_time
+          << ") latest prefetch = " << latest_prefetch_time
+          << " last use = " << last_use_time << " use = " << use.ToString()
+          << ". Size = " << size
           << ", def pos = " << defining_position.ToString()
           << ", operand = " << operand->ToShortString()
           << (non_bitcast_operand != operand
@@ -445,19 +490,6 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
     }
   }
 
-  // TODO(berkin): This is curently overly restrictive and will fail using
-  // alternate memory for any buffer that might leak into a different
-  // computation (e.g., while body). Enable more usage of alternate memory
-  // across computations.
-  if (defining_position.instruction->parent() != use.instruction->parent() ||
-      (!use.instruction->called_computations().empty() &&
-       use.instruction->opcode() != HloOpcode::kFusion)) {
-    VLOG(3) << "Use is in a different computation or calls a computation.";
-    // Fail because we do not allow asynchronous copies while in the bodies of
-    // other computation.
-    return false;
-  }
-
   // First try keeping the allocation entirely in the alternate memory.
   if (!definition_requires_buffer_in_default_mem &&
      !use_requires_buffer_in_default_mem &&
@@ -491,7 +523,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
                                                  prev_allocation->end_time())) {
       AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
                    prev_allocation->start_time(), prev_allocation->end_time(),
-                   allocations);
+                   prev_allocation->end_time(), allocations);
 
     } else {
       VLOG(3) << "This violates the maximum async copies.";
@@ -504,7 +536,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
         if (!ViolatesMaximumOutstandingAsyncCopies(time, time)) {
           VLOG(3) << "Eviction successful.";
           AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
-                       time, time, allocations);
+                       time, time, time, allocations);
           eviction_scheduled = true;
           break;
         }
@@ -558,7 +590,8 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
   //                                   ^      ^
   //                                 Copy    Copy
   //                                 Start   Done
-  options_.prefetch_interval_picker->Begin(use, start_time, end_time);
+  options_.prefetch_interval_picker->Begin(use, start_time,
+                                           latest_prefetch_time);
   while (!options_.prefetch_interval_picker->Done()) {
     alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
     VLOG(4) << "Trying alternate memory allocation ("
@@ -583,7 +616,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
 
       AddAsyncCopy(*allocations->back().get(), MemorySpace::kAlternate,
                    chunk_candidate.chunk, alternate_mem_interval.start,
-                   end_time, allocations);
+                   end_time, latest_prefetch_time, allocations);
 
       allocations->back()->AddUse(use);
       return true;
@@ -598,16 +631,19 @@ bool AlternateMemoryBestFitHeap::FindAllocation(
 
 void AlternateMemoryBestFitHeap::AddAsyncCopy(
     const MemorySpaceAssignment::Allocation& prev_allocation,
     MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time,
+    int64 copy_done_schedule_before_time,
     MemorySpaceAssignment::AllocationSequence* allocations) {
   VLOG(3) << "Copy to "
           << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
                   ? "default"
                  : "alternate")
-          << " memory between " << start_time << " and " << end_time;
+          << " memory between " << start_time << " and "
+          << copy_done_schedule_before_time << " keeping until " << end_time;
   allocations->push_back(
       absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
-          prev_allocation, memory_space, chunk, start_time, end_time));
+          prev_allocation, memory_space, chunk, start_time, end_time,
+          copy_done_schedule_before_time));
 
   // Register the additional async copy with the interval tree to keep track of
   // the limit at any given time.
@@ -828,9 +864,12 @@ MemorySpaceAssignment::Run(HloModule* module, const Options& options) {
       &memory_space_assignment.allocation_map_, options, *alias_analysis,
       *hlo_live_range);
 
+  HeapSimulator::Options heap_simulator_options;
+  heap_simulator_options.may_reuse_operand_buffers = false;
   TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
                                         module->schedule(),
-                                        *alias_analysis.get(), options.size_fn)
+                                        *alias_analysis.get(), options.size_fn,
+                                        heap_simulator_options)
                          .status());
 
   TF_RETURN_IF_ERROR(memory_space_assignment.Process());
@@ -1221,28 +1260,30 @@ Status MemorySpaceAssignment::FixSchedule() {
          instruction_index <
          flattened_instruction_sequence_.instructions().size();
          ++instruction_index) {
-      HloInstruction* instruction =
-          flattened_instruction_sequence_.instructions()[instruction_index];
-      if (instruction->parent() != computation) {
-        continue;
-      }
       auto insts_before_iter = schedule_before_.find(instruction_index);
       if (insts_before_iter != schedule_before_.end()) {
         for (HloInstruction* new_instruction : insts_before_iter->second) {
-          EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
-                                               &inserted_instructions);
+          if (new_instruction->parent() == computation) {
+            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
+                                                 &inserted_instructions);
+          }
         }
       }
+      HloInstruction* instruction =
+          flattened_instruction_sequence_.instructions()[instruction_index];
       // Insert only if not previously inserted.
-      if (!inserted_instructions.contains(instruction)) {
+      if (!inserted_instructions.contains(instruction) &&
+          instruction->parent() == computation) {
         EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                              &inserted_instructions);
       }
       auto insts_after_iter = schedule_after_.find(instruction_index);
       if (insts_after_iter != schedule_after_.end()) {
         for (HloInstruction* new_instruction : insts_after_iter->second) {
-          EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
-                                               &inserted_instructions);
+          if (new_instruction->parent() == computation) {
+            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
+                                                 &inserted_instructions);
+          }
         }
       }
     }
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index b6cf5a0b4ff..bfc91664bea 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -268,6 +268,10 @@ class MemorySpaceAssignment {
     // Specifies the upper bound for number of outstanding asynchronous copies,
     // -1 for unlimited.
     int64 max_outstanding_async_copies = -1;
+
+    // If true, tries allocating buffers across (e.g., before and inside a while
+    // loop body) sequential calls (kWhile, kCall, and kConditional).
+    bool allocate_across_sequential_calls = false;
   };
 
   // This class represents an allocation that might either be in the default or
@@ -363,13 +367,14 @@ class MemorySpaceAssignment {
   class CopyAllocation : public Allocation {
    public:
     CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
-                   Chunk chunk, int64 start_time, int64 end_time)
+                   Chunk chunk, int64 start_time, int64 end_time,
+                   int64 copy_done_schedule_before_time)
         : Allocation(/*instruction=*/nullptr,
                      /*defining_position=*/{nullptr, {}}, memory_space, chunk,
                      start_time, end_time),
           prev_allocation_(prev_allocation),
           copy_start_schedule_after_(start_time),
-          copy_done_schedule_before_(end_time) {}
+          copy_done_schedule_before_(copy_done_schedule_before_time) {}
 
     bool is_copy_allocation() const override { return true; }
 
@@ -525,8 +530,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // allocations can be in default or alternate memory spaces, or can be
   // prefetches or evictions. Returns true if successful.
   bool FindAllocation(int64 start_time, int64 end_time, int64 last_use_time,
-                      HloPosition defining_position, HloUse use,
-                      const HloValue* buffer, int64 size,
+                      int64 latest_prefetch_time, HloPosition defining_position,
+                      HloUse use, const HloValue* buffer, int64 size,
                       MemorySpaceAssignment::AllocationSequence* allocations);
 
   // Try allocating in alternate memory without any copies. Returns true if
@@ -560,7 +565,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // Adds an asynchronous copy to the allocations.
   void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                     MemorySpace memory_space, Chunk chunk, int64 start_time,
-                    int64 end_time,
+                    int64 end_time, int64 copy_done_schedule_before_time,
                     MemorySpaceAssignment::AllocationSequence* allocations);
 
   // These methods are used for delaying committing the chunk candidate until
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 7ec4ddcb3d6..637259032da 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -35,7 +35,8 @@ int64 ShapeSize(const Shape& shape) {
   return ShapeUtil::ByteSizeOf(shape, kPointerSize);
 }
 
-class MemorySpaceAssignmentTest : public HloTestBase {
+class MemorySpaceAssignmentTest : public HloTestBase,
+                                  public ::testing::WithParamInterface<bool> {
  protected:
   // We use the following two memory space values to describe the default (slow
   // and large) and alternate (fast and small) memory spaces.
@@ -105,6 +106,7 @@ class MemorySpaceAssignmentTest : public HloTestBase {
     options.size_fn = size_fn;
     options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
     options.max_outstanding_async_copies = max_outstanding_async_copies;
+    options.allocate_across_sequential_calls = GetParam();
     std::unique_ptr<PresetAssignments> preset_assignments =
         MemorySpaceAssignment::Run(module, options).ValueOrDie();
     CheckPresetAssignments(preset_assignments.get());
@@ -190,7 +192,7 @@ class MemorySpaceAssignmentTest : public HloTestBase {
   }
 };
 
-TEST_F(MemorySpaceAssignmentTest, ParameterOnly) {
+TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
   // A module consisting of a single parameter. Inputs/outputs are currently
   // excluded from memory space assignment.
   HloComputation::Builder builder(TestName());
@@ -210,7 +212,7 @@ TEST_F(MemorySpaceAssignmentTest, ParameterOnly) {
   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
 }
 
-TEST_F(MemorySpaceAssignmentTest, Simple) {
+TEST_P(MemorySpaceAssignmentTest, Simple) {
   // A simple module with a few simple instructions. Expect this to be
   // transformed with CopyStart and CopyDone instructions inserted after inputs
   // and before outputs.
@@ -256,7 +258,7 @@ TEST_F(MemorySpaceAssignmentTest, Simple) {
             preset_assignments->chunks()[1].second.offset);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NegateChain) {
+TEST_P(MemorySpaceAssignmentTest, NegateChain) {
   // The negate chain is long enough for asynchronous copy to be inserted
   // between p1 and add.
   HloComputation::Builder builder(TestName());
@@ -319,7 +321,7 @@ TEST_F(MemorySpaceAssignmentTest, NegateChain) {
   EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
 }
 
-TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) {
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetch) {
   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
 
   AssignMemorySpace(module.get());
@@ -330,12 +332,9 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) {
           op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                         op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
                                       op::Tanh()))));
-
-  EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module),
-            2);
 }
 
-TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
 
   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0);
@@ -344,7 +343,7 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
             0);
 }
 
-TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
 
   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);
@@ -353,7 +352,16 @@ TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
             1);
 }
 
-TEST_F(MemorySpaceAssignmentTest, While) {
+TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) {
+  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
+
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2);
+
+  EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module),
+            2);
+}
+
+TEST_P(MemorySpaceAssignmentTest, While) {
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
@@ -429,14 +437,18 @@ TEST_F(MemorySpaceAssignmentTest, While) {
   AssignMemorySpace(module.get());
 
   // Ensure the tuple value and buffers used in the while instruction are
-  // exempted from using the alternate memory. However, body_data_mul is
-  // independent and can be safely be placed in the alternate memory.
-  EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
-  EXPECT_THAT(data, op::ShapeWithLayout(shape));
-  EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
-  EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
-  EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
-  EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
+  // exempted from using the alternate memory when allocating across sequential
+  // calls is disabled. However, body_data_mul is independent and can be safely
+  // be placed in the alternate memory.
+  const bool allocate_across_sequential_calls = GetParam();
+  if (!allocate_across_sequential_calls) {
+    EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
+    EXPECT_THAT(data, op::ShapeWithLayout(shape));
+    EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
+    EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
+    EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
+    EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
+  }
   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
       F32, {2, 3},
       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
@@ -444,7 +456,7 @@ TEST_F(MemorySpaceAssignmentTest, While) {
   EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
 }
 
-TEST_F(MemorySpaceAssignmentTest, Tuple) {
+TEST_P(MemorySpaceAssignmentTest, Tuple) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
@@ -499,7 +511,7 @@ TEST_F(MemorySpaceAssignmentTest, Tuple) {
                   op::GetTupleElement(op::GetTupleElement()))));
 }
 
-TEST_F(MemorySpaceAssignmentTest, Bitcast) {
+TEST_P(MemorySpaceAssignmentTest, Bitcast) {
   // Bitcasts can cause the position in the alternate memory to appear multiple
   // times in the preset assignments. This test ensure the preset assignments
   // refer to unique positions.
@@ -528,7 +540,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast) {
   EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
 }
 
-TEST_F(MemorySpaceAssignmentTest, Bitcast2) {
+TEST_P(MemorySpaceAssignmentTest, Bitcast2) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
@@ -564,7 +576,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast2) {
   EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
 }
 
-TEST_F(MemorySpaceAssignmentTest, Bitcast3) {
+TEST_P(MemorySpaceAssignmentTest, Bitcast3) {
   HloComputation::Builder builder(TestName());
   Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
   Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
@@ -627,7 +639,7 @@ TEST_F(MemorySpaceAssignmentTest, Bitcast3) {
   EXPECT_EQ(bitcast4->shape().layout().memory_space(), kAlternateMemorySpace);
 }
 
-TEST_F(MemorySpaceAssignmentTest, BitcastTuple) {
+TEST_P(MemorySpaceAssignmentTest, BitcastTuple) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
@@ -678,7 +690,7 @@ TEST_F(MemorySpaceAssignmentTest, BitcastTuple) {
   AssignMemorySpace(module.get());
 }
 
-TEST_F(MemorySpaceAssignmentTest, LastUseOpt) {
+TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
   // Test that checks the last use optimization. It uses two buffers that should
   // be placed in alternate memory.
   //
@@ -735,7 +747,7 @@ TEST_F(MemorySpaceAssignmentTest, LastUseOpt) {
                       op::Add(op::Parameter(0), op::Parameter(0)))));
 }
 
-TEST_F(MemorySpaceAssignmentTest, CopyOrdering) {
+TEST_P(MemorySpaceAssignmentTest, CopyOrdering) {
   // Test to make sure the CopyStarts follow the same CopyDone order. The shapes
   // are picked in increasing order to exploit the fact that heap simulator
   // processes larger tensors first. This checks the ability of the compiler to
@@ -850,7 +862,7 @@ TEST_F(MemorySpaceAssignmentTest, CopyOrdering) {
   }
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
   // Test to ensure CopyStart/CopyDone is placed only in the entry computation.
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
@@ -934,7 +946,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
   AssignMemorySpace(module.get(), -1, 50);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@@ -1005,7 +1017,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
   AssignMemorySpace(module.get(), -1, 5);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@@ -1071,7 +1083,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
   AssignMemorySpace(module.get(), -1, 5);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
   auto module = CreateNewVerifiedModule();
   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
@@ -1144,7 +1156,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
   AssignMemorySpace(module.get(), -1, 5);
 }
 
-TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
   // This test reproduces the failure in b/143288178. Given a graph like the
   // following:
   //
@@ -1242,7 +1254,7 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
       tuple_shape, cond_computation, body_computation, tuple));
   HloInstruction* while_data = builder.AddInstruction(
-      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
+      HloInstruction::CreateGetTupleElement(scalar_shape, while_op, 1));
   HloInstruction* root =
       builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
   HloComputation* entry_computation =
@@ -1265,7 +1277,143 @@ TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
   AssignMemorySpace(module.get(), -1, 20);
 }
 
-TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
+TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
+  auto module = CreateNewVerifiedModule();
+  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
+  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape, shape});
+
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+  HloInstruction* cond_iter = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+  HloInstruction* cond_limit = cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
+  HloInstruction* cond_lt = cond_builder.AddInstruction(
+      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
+                                    cond_limit, ComparisonDirection::kLt));
+  HloComputation* cond_computation =
+      module->AddEmbeddedComputation(cond_builder.Build());
+
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+  HloInstruction* body_iter = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
+  HloInstruction* body_data = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
+  HloInstruction* body_negate0 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_data));
+  HloInstruction* body_negate1 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate0));
+  HloInstruction* body_negate2 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate1));
+  HloInstruction* body_negate3 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate2));
+  HloInstruction* body_negate4 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate3));
+  HloInstruction* body_negate5 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate4));
+  HloInstruction* body_negate6 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate5));
+  HloInstruction* body_negate7 = body_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate6));
+  HloInstruction* body_iter_increment = body_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
+  HloInstruction* body_iter_next =
+      body_builder.AddInstruction(HloInstruction::CreateBinary(
+          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
+  HloInstruction* body_out = body_builder.AddInstruction(
+      HloInstruction::CreateTuple({body_data, body_iter_next, body_negate7}));
+  HloComputation* body_computation =
+      module->AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* data = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param_data"));
+  HloInstruction* iter = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* negate7 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
+  HloInstruction* tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({data, iter, negate7}));
+  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, cond_computation, body_computation, tuple));
+  HloInstruction* while_data = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
+  HloInstruction* while_data2 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, while_op, 2));
+  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kAdd, while_data, while_data2));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(cond_computation,
+                        {cond_param, cond_iter, cond_limit, cond_lt});
+  schedule.set_sequence(
+      body_computation,
+      {body_param, body_iter, body_data, body_negate0, body_negate1,
+       body_negate2, body_negate3, body_negate4, body_negate5, body_negate6,
+       body_negate7, body_iter_increment, body_iter_next, body_out});
+  schedule.set_sequence(
+      entry_computation,
+      {iter, data, negate0, negate1, negate2, negate3, negate4, negate5,
+       negate6, negate7, tuple, while_op, while_data, while_data2, root});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  // Pick a large max prefetch interval to ensure all the while inputs are
+  // allocated in the alternate memory.
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
+                    /*max_prefetch_interval=*/25);
+
+  int64 memory_space_across_while = kDefaultMemorySpace;
+  bool allocate_across_sequential_calls = GetParam();
+  if (allocate_across_sequential_calls) {
+    memory_space_across_while = kAlternateMemorySpace;
+  }
+
+  // Index {0} of the while loop argument is not written inside the while loop,
+  // so it can be trivially placed in the alternate memory space.
+  *ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
+      LayoutUtil::MakeLayout(
+          /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+          kAlternateMemorySpace);
+  // Indexes {1} and {2} of the while loop argument are only placed in the
+  // alternate memory if we enable the allocate_across_sequential_calls option.
+  *ShapeUtil::GetMutableSubshape(&tuple_shape, {1})->mutable_layout() =
+      LayoutUtil::MakeLayout(
+          /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+          memory_space_across_while);
+  *ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
+      LayoutUtil::MakeLayout(
+          /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
+          memory_space_across_while);
+
+  // Expect the layout for the while loop and its aliased buffers.
+  EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
+  EXPECT_THAT(while_op->operand(0), op::ShapeWithLayout(tuple_shape));
+  EXPECT_THAT(cond_param, op::ShapeWithLayout(tuple_shape));
+  EXPECT_THAT(body_param, op::ShapeWithLayout(tuple_shape));
+  EXPECT_THAT(body_out, op::ShapeWithLayout(tuple_shape));
+}
+
+TEST_P(MemorySpaceAssignmentTest, DanglingCopy) {
   // This situation was encountered in vss, where there is a mismatch in the
   // memory space in preset assignments and the output graph.
   HloComputation::Builder builder(TestName());
@@ -1311,7 +1459,7 @@ TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
   AssignMemorySpace(module.get());
 }
 
-TEST_F(MemorySpaceAssignmentTest, MultiOutputFusion) {
+TEST_P(MemorySpaceAssignmentTest, MultiOutputFusion) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1348,7 +1496,7 @@ TEST_F(MemorySpaceAssignmentTest, MultiOutputFusion) {
   AssignMemorySpace(module.get());
 }
 
-TEST_F(MemorySpaceAssignmentTest, TupleInput) {
+TEST_P(MemorySpaceAssignmentTest, TupleInput) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1388,7 +1536,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleInput) {
   AssignMemorySpace(module.get());
 }
 
-TEST_F(MemorySpaceAssignmentTest, TupleToTuple1) {
+TEST_P(MemorySpaceAssignmentTest, TupleToTuple1) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1467,7 +1615,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple1) {
                             op::GetTupleElement(op::Fusion(), 1)))));
 }
 
-TEST_F(MemorySpaceAssignmentTest, TupleToTuple2) {
+TEST_P(MemorySpaceAssignmentTest, TupleToTuple2) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1547,7 +1695,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple2) {
                       op::GetTupleElement(op::Fusion(), 1), 1))))));
 }
 
-TEST_F(MemorySpaceAssignmentTest, TupleToTuple3) {
+TEST_P(MemorySpaceAssignmentTest, TupleToTuple3) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1594,7 +1742,7 @@ TEST_F(MemorySpaceAssignmentTest, TupleToTuple3) {
   EXPECT_THAT(fusion1, op::Fusion(op::Fusion()));
 }
 
-TEST_F(MemorySpaceAssignmentTest, InputOutputAlias) {
+TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
   HloComputation::Builder builder(TestName());
   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
@@ -1649,7 +1797,7 @@ TEST_F(MemorySpaceAssignmentTest, InputOutputAlias) {
             kDefaultMemorySpace);
 }
 
-TEST_F(MemorySpaceAssignmentTest, CostAnalysis) {
+TEST_P(MemorySpaceAssignmentTest, CostAnalysis) {
   // This is mostly a smoke test since it's difficult and brittle to work out
   // the cost of the HLO instructions.
   HloComputation::Builder builder(TestName());
@@ -1701,7 +1849,7 @@ TEST_F(MemorySpaceAssignmentTest, CostAnalysis) {
   EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
 }
 
-TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
+TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
   // This test is carefully crafted to force only negates to be allocated to the
   // alternate memory. The graph consists of interleaving negate and tanh
   // operations:
@@ -1762,16 +1910,16 @@ TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
       F32, {4, 6},
       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
       kDefaultMemorySpace);
-  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
-      F32, {4, 6},
-      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
-      kAlternateMemorySpace);
-  // Expect only negates to be in alternate memory space.
-  EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
-  EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
+  // Expect only negates to be in alternate memory space. Not all might fit but
+  // make sure at least one does.
+  std::vector<HloInstruction*> negate_instructions = {negate0, negate1, negate2,
+                                                      negate3, negate4};
+  int64 num_negates_in_alternate_mem = absl::c_count_if(
+      negate_instructions, [&](const HloInstruction* instruction) {
+        return instruction->shape().layout().memory_space() ==
+               kAlternateMemorySpace;
+      });
+  EXPECT_GE(num_negates_in_alternate_mem, 1);
   EXPECT_THAT(tanh0, op::ShapeWithLayout(shape_in_default_mem));
   EXPECT_THAT(tanh1, op::ShapeWithLayout(shape_in_default_mem));
   EXPECT_THAT(tanh2, op::ShapeWithLayout(shape_in_default_mem));
@@ -1779,5 +1927,9 @@ TEST_F(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
   EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
 }
 
+INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
+                         MemorySpaceAssignmentTest,
+                         ::testing::Values(false, true));
+
 }  // namespace
 }  // namespace xla
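
Note (not part of the patch): a minimal sketch of how a caller would opt into the new behavior, for reviewers trying the change out. It assumes a scheduled HloModule* module and uses only the Options fields visible in the hunks above; the byte-size callback is an illustrative stand-in for whatever size function the embedder already passes to the heap simulator, and the 8-byte pointer size is an assumption.

  // Sketch: enable allocation across sequential calls (kWhile/kCall/kConditional)
  // when invoking the pass. `module` is assumed to be a scheduled HloModule.
  MemorySpaceAssignment::Options options;
  // Byte-size callback, the same shape of callback HeapSimulator::Run receives.
  options.size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };
  options.max_outstanding_async_copies = -1;  // unlimited, as before this change
  // New option added by this change: keep buffers in alternate memory across
  // sequential calls instead of forcing aliased colocated intervals to default
  // memory.
  options.allocate_across_sequential_calls = true;
  std::unique_ptr<PresetAssignments> preset_assignments =
      MemorySpaceAssignment::Run(module, options).ValueOrDie();

Leaving the flag at its default of false preserves the previous behavior, which is why the test suite above is parameterized over both values.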