From 517e0274df92535497e1866caba282434d612bd3 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Mon, 13 Apr 2020 10:34:24 -0700 Subject: [PATCH] [XLA] Better support for while loops in memory space assignment. This CL makes changes to support the aliasing requirements of while loops. It introduces the AllocationValue object which is similar to HloValue. While HloValue may include multiple HloPositions across multiple Computations (that alias with each other), there is one AllocationValue object per non-trivial (excluding GTE, Tuple, and Bitcast) position. This data structure allows memory space assignment to visit positions and uses in each computation separately, and then propagate the aliased allocation decisions to other AllocationValue objects. Using this new data structure, memory space assignment now properly propagates aliased positions and uses using required assignments. This CL also introduces GetTransitiveColocationIntervals to HeapSimulator that returns colocated intervals instead of HloValues. Using this API, memory space assignment represents aliased logical times where AllocationValues must have the same buffer assignment (e.g. while loop body parameter and while loop output). PiperOrigin-RevId: 306260022 Change-Id: I169cc2a6ee78c5ad9b7bcc366e519f7c5750a4e2 --- .../compiler/xla/service/buffer_assignment.cc | 27 +- .../xla/service/buffer_assignment_test.cc | 7 - .../xla/service/memory_space_assignment.cc | 822 +++++++++++++----- .../xla/service/memory_space_assignment.h | 171 +++- .../service/memory_space_assignment_test.cc | 320 ++++++- 5 files changed, 1081 insertions(+), 266 deletions(-) diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 6cbd33053fc..67cdb081a91 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1392,23 +1392,22 @@ Status BufferAssigner::AssignPresetBuffers( } const HloAliasAnalysis& alias_analysis = assignment->alias_analysis(); - const HloDataflowAnalysis& dataflow_analysis = - alias_analysis.dataflow_analysis(); for (auto& position_and_chunk : preset_assignments_->chunks()) { - const HloPosition& position = position_and_chunk.first; - const HloValue& value = dataflow_analysis.GetUniqueValueAt( - position.instruction, position.index); - VLOG(3) << "Preset allocation for value: " << value.ToShortString(); - const HeapSimulator::Chunk& chunk = position_and_chunk.second; - auto preset_allocations_iter = preset_allocations.find(value.color()); - CHECK(preset_allocations_iter != preset_allocations.end()) - << "No preset value allocation for color " << value.color() << " for " - << value.ToShortString() << " found."; - preset_allocations_iter->second->AddAssignment(value, chunk.offset, - chunk.size); + const HloPosition& defining_position = position_and_chunk.first; + const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt( + defining_position.instruction, defining_position.index); + for (const HloValue* value : buffer.values()) { + VLOG(3) << "Preset allocation for value: " << value->ToShortString(); + const HeapSimulator::Chunk& chunk = position_and_chunk.second; + auto preset_allocations_iter = preset_allocations.find(value->color()); + CHECK(preset_allocations_iter != preset_allocations.end()) + << "No preset value allocation for color " << value->color() + << " for " << value->ToShortString() << " found."; + preset_allocations_iter->second->AddAssignment(*value, chunk.offset, + chunk.size); + } - 
const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value); assigned_buffers->insert(&buffer); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 13166e9a9e5..8a75a7b01a9 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -835,13 +835,6 @@ TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) { // Set only one preset assignment for while data and its aliases. auto preset_assignments = absl::make_unique(); preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40}); - preset_assignments->add_chunk({while_op, {1}}, {/*offset=*/100, /*size=*/40}); - preset_assignments->add_chunk({cond_param, {1}}, - {/*offset=*/100, /*size=*/40}); - preset_assignments->add_chunk({body_param, {1}}, - {/*offset=*/100, /*size=*/40}); - preset_assignments->add_chunk({body_data_next, {}}, - {/*offset=*/100, /*size=*/40}); preset_assignments->assignment_information_for_space(/*memory_space=*/1) ->size = 140; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 291d1bd4ed4..56170fa3e1d 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -194,7 +194,8 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, int64 start_time, int64 end_time) { - const Shape& shape = use.instruction->operand(use.operand_number)->shape(); + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_. async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape); // Estimate the time we would save by having this op in alternate memory. @@ -255,6 +256,94 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( ", logical interval elapsed (s) = ", logical_interval_elapsed); } +std::string MemorySpaceAssignment::AllocationValue::ToString() const { + std::string out = absl::StrCat("computation = ", computation()->name()); + absl::StrAppend(&out, "\n position:\n"); + absl::StrAppend(&out, " ", defining_position_.ToString(), "\n"); + absl::StrAppend(&out, " uses:\n"); + for (const HloUse& use : uses_) { + absl::StrAppend(&out, " ", use.ToString(), "\n"); + } + return out; +} + +std::string MemorySpaceAssignment::AllocationValue::ToShortString() const { + return absl::StrCat("computation = ", computation()->name(), + ", position = ", defining_position_.ToString(), + ", value = ", value_->ToShortString()); +} + +void AlternateMemoryBestFitHeap::CreateAllocationValues( + const HloValue* value, std::vector* allocation_values) { + VLOG(3) << "Creating AllocationValues for: " << value->ToString(); + + // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast) + // positions. We create an AllocationValue object for each non-trivial + // position. And for each AllocationValue object, we create an + // AllocationSequence consisting of one or more Allocation objects.The reason + // why we exclude the trivial positions from AllocationValue is because + // Allocation objects have special support for tuples and bitcasts. 
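// Editorial sketch (not part of this change): the "non-trivial position"
// filter described above amounts to a predicate like the following, using
// only types already present in this file:

static bool IsNonTrivialPosition(const HloPosition& position) {
  const HloOpcode opcode = position.instruction->opcode();
  return opcode != HloOpcode::kGetTupleElement &&
         opcode != HloOpcode::kTuple && opcode != HloOpcode::kBitcast;
}

// Every position that passes this filter gets its own AllocationValue (and
// its own AllocationSequence); GTE, Tuple, and Bitcast positions are instead
// handled by the Allocation machinery itself (see AddGetTupleElements and
// Allocation::Process later in this file).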
+ const absl::flat_hash_map& + instruction_schedule = hlo_live_range_.instruction_schedule(); + std::vector positions; + for (const HloPosition& position : value->positions()) { + const HloInstruction* instruction = position.instruction; + if (instruction->opcode() != HloOpcode::kGetTupleElement && + instruction->opcode() != HloOpcode::kTuple && + instruction->opcode() != HloOpcode::kBitcast) { + positions.push_back(position); + } + } + absl::c_stable_sort(positions, + [&](const HloPosition& pos1, const HloPosition& pos2) { + return instruction_schedule.at(pos1.instruction) < + instruction_schedule.at(pos2.instruction); + }); + + // Create an AllocationValue for each non-trivial position. + absl::flat_hash_set computations; + int beginning_idx = allocation_values->size(); + for (int i = 0; i < positions.size(); ++i) { + const HloPosition& position = positions.at(i); + allocation_values->emplace_back(value, position); + } + + std::vector uses(value->uses()); + absl::c_stable_sort(uses, [&](const HloUse& use1, const HloUse& use2) { + return instruction_schedule.at(use1.instruction) < + instruction_schedule.at(use2.instruction); + }); + + // Associate each use with an AllocationValue. Each AllocationValue contains a + // position and uses in the same computation. Furthermore, if the original + // HloValue had multiple non-trivial positions in the same computation, those + // will get their own AllocationValue as well. We split these HloValues so + // that when we insert CopyStart/CopyDone in CopyAllocation::Process, they + // point to the latest position. We then replace the operand of the use with + // CopyStart/CopyDone with an operand of the latest position. + for (const HloUse& use : uses) { + int64 use_time = instruction_schedule.at(use.instruction); + HloComputation* use_computation = use.instruction->parent(); + + AllocationValue* last_allocation_value = nullptr; + for (int i = beginning_idx; i < allocation_values->size(); ++i) { + AllocationValue* allocation_value = &allocation_values->at(i); + if (allocation_value->computation() == use_computation && + instruction_schedule.at( + allocation_value->defining_position().instruction) < use_time) { + last_allocation_value = allocation_value; + } + } + CHECK(last_allocation_value != nullptr); + last_allocation_value->AddUse(use, use_time); + } + + for (int i = beginning_idx; i < allocation_values->size(); ++i) { + VLOG(3) << "Created allocation value: " + << allocation_values->at(i).ToString(); + } +} + std::vector AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { @@ -297,9 +386,9 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( // The semantics of TupleSelect are weird: TupleSelect doesn't define a // buffer, but just forwards the buffers in the either left or right side. - // This means the the two different inputs to TupleSelect must not alias, yet - // they should be allocated in the same memory space, and both buffers must be - // kept alive for the entire live range of TupleSelect. Instead, just don't + // This means the two different inputs to TupleSelect must not alias, yet they + // should be allocated in the same memory space, and both buffers must be kept + // alive for the entire live range of TupleSelect. Instead, just don't // allocate TupleSelect in the alternate memory space. // TODO(berkin): Not allocating add-dependencies either since they need to be // treated specially. We should revisit this later. 
@@ -337,6 +426,76 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( return true; } +bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( + const HloUse& use) const { + if (use.instruction->opcode() == HloOpcode::kWhile) { + HloComputation* while_body = use.instruction->while_body(); + + // We don't want to allocate this buffer in alternate memory if it will be + // evicted anyway. Find out if it has an early use or a late definition that + // would make sense to keep it in the alternate memory. + HloValue* parameter_value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt( + while_body->parameter_instruction(0), use.operand_index); + const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); + int64 parameter_time = + instruction_schedule.at(while_body->parameter_instruction(0)); + int64 root_time = instruction_schedule.at(while_body->root_instruction()); + int64 min_use_time = root_time; + for (const HloUse& parameter_use : parameter_value->uses()) { + if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement && + parameter_use.instruction->opcode() != HloOpcode::kTuple && + parameter_use.instruction->opcode() != HloOpcode::kBitcast) { + min_use_time = std::min( + min_use_time, instruction_schedule.at(parameter_use.instruction)); + } + } + // If there is no use of this buffer inside the while loop, there is no need + // to allocate it in the loop. + if (min_use_time == root_time) { + VLOG(4) << "While allocation not allowed in alternate memory. " + << "use time = " << min_use_time << ", root time = " << root_time; + return false; + } + HloValue* root_value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt( + while_body->root_instruction(), use.operand_index); + int64 root_definition_time = + instruction_schedule.at(root_value->defining_instruction()); + const Shape& shape = root_value->shape(); + // Allow the buffer in alternate memory if the buffer has a short live range + // either at the beginning or end of the while loop body. + if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( + shape, parameter_time, min_use_time) && + !options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( + shape, root_definition_time, root_time)) { + VLOG(4) << "While allocation not allowed in alternate memory. " + << "use time = " << min_use_time + << ", def time = " << root_definition_time + << ", root time = " << root_time; + return false; + } + // Check if there is a required assignment for the while loop output. + HloValue* while_value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt( + use.instruction, use.operand_index); + int64 while_time = instruction_schedule.at(use.instruction); + auto existing_required_assignment = + RequiredMemoryAssignmentAt(while_value, while_time); + if (existing_required_assignment) { + // TODO(berkin): Failing for now when the output is requested to be in + // alternate memory, and the buffer is a while loop output. 
+ CHECK(existing_required_assignment->memory_space == MemorySpace::kDefault) + << "While loop buffers pinned to alternate memory not " + "currently supported."; + VLOG(4) << "While allocation not allowed in alternate memory because " + "there is a required default memory assignment."; + return false; + } + } + return true; +} + HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -346,6 +505,16 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { AddInputAndOutputRequiredAssignments(); + if (VLOG_IS_ON(4)) { + VLOG(4) << "Flattened instruction sequence:"; + const auto& instruction_sequence = + hlo_live_range_.flattened_instruction_sequence().instructions(); + for (int i = 0; i < instruction_sequence.size(); ++i) { + VLOG(4) << " " << i << ": " << instruction_sequence[i]->parent()->name() + << " " << instruction_sequence[i]->name(); + } + } + for (auto& interval : sorted_buffer_intervals) { if (!interval.need_allocation) { continue; @@ -363,10 +532,18 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { absl::c_count(module->CrossProgramPrefetches(), std::make_pair(inst->parameter_number(), interval.buffer->index())) > 0) { + VLOG(3) << "Skip " << interval.buffer->ToShortString() + << " because it is cross-program prefetched."; continue; } auto colocated_intervals = GetSortedColocatedIntervals(interval); + // Create AllocationValues for all the + // colocated intervals. + std::vector allocation_values; + for (const auto& colocated_interval : colocated_intervals) { + CreateAllocationValues(colocated_interval->buffer, &allocation_values); + } if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { VLOG(4) << "Interval " << interval.buffer->ToShortString() @@ -409,56 +586,59 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { continue; } - const HloComputation* defining_computation = - colocated_intervals[0]->buffer->defining_instruction()->parent(); - MemorySpaceAssignment::Allocation* aliased_allocation = nullptr; + const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); + global_max_time_ = instruction_schedule.at( + module->entry_computation()->root_instruction()); + + // TODO(berkin): For now, place the phi values due to conditionals in + // default memory. 
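// (Editorial note: the "phi values" mentioned above are the positions where a
// conditional merges the roots of its branch computations into its own
// output, analogous to SSA phi nodes. The loop below pins them to default
// memory by adding required assignments for the conditional output and for
// every branch root at the same shape index, which avoids having to give all
// branch roots an identical alternate-memory offset.)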
for (const BufferInterval* colocated_interval : colocated_intervals) { const HloValue* value = colocated_interval->buffer; - const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - allocation_sequence_list_->push_back({value, {}}); - MemorySpaceAssignment::AllocationSequence* allocation_sequence = - &allocation_sequence_list_->back().sequence; + for (const auto& position : value->positions()) { + if (position.instruction->opcode() == HloOpcode::kConditional) { + VLOG(3) << "Adding required assignment for condition output: " + << value->ToShortString(); + required_assignments_[value].push_back( + {MemorySpace::kDefault, + instruction_schedule.at(position.instruction), + /*chunk=*/absl::nullopt}); + for (const HloComputation* called_computation : + position.instruction->called_computations()) { + HloValue* root_value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt( + called_computation->root_instruction(), position.index); + required_assignments_[root_value].push_back( + {MemorySpace::kDefault, + instruction_schedule.at( + called_computation->root_instruction()), + /*chunk=*/absl::nullopt}); + } + } + } + } + + // Data structure to contain the preferred offset for a given computation. + // We ensure that the same offset will be allocated outside the while loop + // as well as inside the while loop. + absl::flat_hash_map + preferred_offset_for_computation; + bool allocation_success = true; + for (auto& allocation_value : allocation_values) { int64 definition_time = - instruction_schedule.at(value->defining_instruction()); - // Sort the uses by the use time. - std::vector uses = value->uses(); - absl::c_stable_sort(uses, [&](HloUse use1, HloUse use2) { - return instruction_schedule.at(use1.instruction) < - instruction_schedule.at(use2.instruction); - }); + instruction_schedule.at(allocation_value.defining_instruction()); - // If there was an aliased allocation for this buffer, propagate that for - // this HloValue. - if (aliased_allocation != nullptr) { - VLOG(3) << "Adding an aliased allocation: (" - << aliased_allocation->start_time() << ", " - << aliased_allocation->end_time() - << ") pos: " << aliased_allocation->defining_position() - << " mem space: " - << (aliased_allocation->memory_space() == MemorySpace::kDefault - ? "default" - : "alt"); - allocation_sequence->push_back( - absl::make_unique( - value->defining_position(), aliased_allocation->memory_space(), - aliased_allocation->chunk(), definition_time, definition_time)); + absl::optional preferred_offset; + auto preferred_offset_it = + preferred_offset_for_computation.find(allocation_value.computation()); + if (preferred_offset_it != preferred_offset_for_computation.end()) { + preferred_offset = preferred_offset_it->second; } - std::vector use_times(uses.size()); - for (int i = 0; i < uses.size(); ++i) { - use_times[i] = instruction_schedule.at(uses[i].instruction); - } // Iterate over the uses. - for (HloUse use : uses) { + for (HloUse use : allocation_value.uses()) { int64 use_time = instruction_schedule.at(use.instruction); int64 latest_prefetch_time = use_time; - if (use.instruction->parent() != defining_computation) { - VLOG(3) << "skip use " << use.ToString() - << " because it's in a different computation."; - continue; - } - // Sequential calls include kWhile, kCall, and kConditional opcodes. 
bool is_sequential_call = (GetInstructionCallContext(use.instruction->opcode()) == @@ -471,6 +651,41 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { latest_prefetch_time = std::min(computation_span.start, latest_prefetch_time); } + if (use.instruction->opcode() == HloOpcode::kWhile) { + // Given an example while loop and flattened schedule (logical times + // shown on the left): + // + // 0: a = ... + // 1: ... + // cond { + // 2: p = param(0) + // 3: ... + // } + // body { + // 4: p = param(0) + // 5: ... + // 6: ROOT ... + // } + // 7: w = while(a), body=body, cond=cond + // + // When processing "a" (time 0) and its while use (time 7), we + // update the interval to time 0-4. This is so that the remaining + // interval (5-6) can be allocated separately and this buffer + // doesn't waste alternate memory space within the while loop body. + HloComputation* while_body = use.instruction->while_body(); + // Replace the use time with the parameter time so that we can + // decide on alternate memory allocations within the while loop body + // when we look at uses within the while loop body. + use_time = + instruction_schedule.at(while_body->parameter_instruction(0)); + } + } + + // Add a required assignment in default memory if the use not allowed in + // alternate memory. + if (!IsUseAllowedInAlternateMemory(use)) { + required_assignments_[allocation_value.value()].push_back( + {MemorySpace::kDefault, use_time, /*chunk=*/absl::nullopt}); } // Bitcasts don't define buffers and don't directly consume buffers. @@ -480,31 +695,50 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { AllocationRequest request; request.start_time = definition_time; request.end_time = use_time; - request.use_times = &use_times; request.latest_prefetch_time = latest_prefetch_time; + request.size = interval.size; + request.preferred_offset = preferred_offset; request.use = use; - request.buffer = value; - request.size = colocated_interval->size; - request.allocations = allocation_sequence; + request.allocation_value = &allocation_value; if (!FindAllocation(request)) { // If the allocation finding failed (e.g., due to running out of // asynchronous copies), then fall back to allocating the buffer // entirely in the default memory. UncommitPendingChunks(); - allocation_sequence->clear(); + allocation_success = false; break; } // If there are multiple uses, they can try using the memory // allocation already at the alternate memory. - definition_time = use_time; + definition_time = instruction_schedule.at(use.instruction); } // If the use has been a sequential call (e.g. a while loop), the other // colocated intervals must alias with this allocation. if (is_sequential_call) { - aliased_allocation = - GetLiveAllocationAt(*allocation_sequence, use_time); + MemorySpaceAssignment::Allocation* aliased_allocation = + GetLiveAllocationAt(*allocation_value.allocation_sequence(), + use_time); + AddAliasedRequiredAssignmentsForSequentialCall(use, + aliased_allocation); + // Remember the preferred offset to be used inside while loop body + // computations. 
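// (Editorial note: the while-body parameter, the while-body root, and the
// while instruction itself must alias. If the value fed into the while was
// given an alternate-memory chunk, AllocationValues defined inside the body
// should therefore ask for the same offset; the map populated below is looked
// up when those body AllocationValues are processed, and the chunk search
// then treats the recorded offset as the required placement.)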
+ if (aliased_allocation->memory_space() == MemorySpace::kAlternate && + use.instruction->opcode() == HloOpcode::kWhile) { + preferred_offset_for_computation[use.instruction->while_body()] = + aliased_allocation->chunk().offset; + } + } + } + if (!allocation_success) { + break; + } + } + if (allocation_success) { + for (AllocationValue& allocation_value : allocation_values) { + for (auto& allocation : *allocation_value.allocation_sequence()) { + allocations_->push_back(std::move(allocation)); } } } @@ -513,21 +747,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { pending_async_copies_.clear(); } - if (VLOG_IS_ON(3)) { - for (const auto& value_and_sequence : *allocation_sequence_list_) { - VLOG(3) << "Allocation for " << value_and_sequence.value->ToShortString(); - for (const auto& alloc : value_and_sequence.sequence) { - std::string addr_str = ": default"; - if (alloc->memory_space() == MemorySpace::kAlternate) { - addr_str = absl::StrCat(": alt ", alloc->chunk().offset); - } - - VLOG(3) << " " << alloc->start_time() << "-" << alloc->end_time() - << addr_str << ", " << alloc->uses().size() << " uses"; - } - } - } - return result_; } @@ -591,15 +810,12 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( int64 parameter = buffer->instruction()->parameter_number(); module->AddCrossProgramPrefetch(parameter, buffer->index()); - allocation_sequence_list_->push_back({buffer, {}}); - MemorySpaceAssignment::AllocationSequence& allocations = - allocation_sequence_list_->back().sequence; - + MemorySpaceAssignment::AllocationSequence allocations; allocations.push_back(absl::make_unique( buffer->defining_position(), MemorySpace::kDefault, kDummyChunk, prefetch_candidate->start, prefetch_candidate->end)); - // Sort the uses by the use time. + // Find the earliest use. const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); auto uses = buffer->uses(); auto first_use = @@ -613,11 +829,79 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( chunk_candidate.chunk, prefetch_candidate->start, prefetch_candidate->end, latest_prefetch_time, &allocations); absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); }); + for (auto& allocation : allocations) { + allocations_->push_back(std::move(allocation)); + } pending_chunks_.clear(); pending_async_copies_.clear(); } +void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall( + const HloUse& use, + const MemorySpaceAssignment::Allocation* aliased_allocation) { + // Add aliased required assignments. 
+ if (use.instruction->opcode() == HloOpcode::kWhile) { + HloComputation* while_body = use.instruction->while_body(); + AddAliasedRequiredAssignment(while_body->parameter_instruction(0), + use.operand_index, aliased_allocation); + AddAliasedRequiredAssignment(while_body->root_instruction(), + use.operand_index, aliased_allocation); + AddAliasedRequiredAssignment(use.instruction, use.operand_index, + aliased_allocation); + } else if (use.instruction->opcode() == HloOpcode::kConditional) { + HloComputation* called_computation = + use.instruction->called_computations().at(use.operand_number - 1); + AddAliasedRequiredAssignment(called_computation->parameter_instruction(0), + use.operand_index, aliased_allocation); + } else { + CHECK(use.instruction->opcode() == HloOpcode::kCall); + HloComputation* called_computation = + use.instruction->called_computations().at(0); + AddAliasedRequiredAssignment( + called_computation->parameter_instruction(use.operand_number), + use.operand_index, aliased_allocation); + } +} + +void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( + const HloInstruction* instruction, ShapeIndex index, + const MemorySpaceAssignment::Allocation* aliased_allocation) { + absl::optional chunk; + if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { + chunk = aliased_allocation->chunk(); + } + const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); + HloValue* value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); + int64 instruction_time = instruction_schedule.at(instruction); + // Check for existing required assignment at this time and make sure it is the + // same as this if there is one. + auto existing_required_assignment = + RequiredMemoryAssignmentAt(value, instruction_time); + if (existing_required_assignment) { + CHECK(aliased_allocation->memory_space() == + existing_required_assignment->memory_space); + CHECK((!chunk && !existing_required_assignment->chunk) || + chunk->offset == existing_required_assignment->chunk->offset); + VLOG(3) << "Not adding aliased required assignment because there is one " + "already: " + << value->ToShortString() << " at " << instruction_time << " at " + << (aliased_allocation->memory_space() == MemorySpace::kDefault + ? "def" + : "alt"); + return; + } + + required_assignments_[value].push_back( + {aliased_allocation->memory_space(), instruction_time, chunk}); + VLOG(3) << "Adding aliased required assignment: " << value->ToShortString() + << " at " << instruction_time << " at " + << (aliased_allocation->memory_space() == MemorySpace::kDefault + ? "def" + : "alt"); +} + void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { // Go through the parameters and outputs and pin them to the corresponding // memory by adding a required assignment. 
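// (Editorial note: a required assignment pins a value to a memory space at a
// specific logical time, optionally with a concrete chunk, e.g.
// {MemorySpace::kDefault, instruction_time, /*chunk=*/absl::nullopt}. This CL
// uses required assignments both here, for module inputs and outputs, and in
// AddAliasedRequiredAssignment above, to propagate aliasing decisions across
// while/conditional/call boundaries.)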
@@ -708,6 +992,8 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() { for (const auto& interval_and_chunk : pending_chunks_) { const BufferInterval& interval = interval_and_chunk.first; const Chunk& chunk = interval_and_chunk.second.chunk; + VLOG(4) << "Uncommitting: (" << interval.start << ", " << interval.end + << ") off = " << chunk.offset << " size = " << chunk.size; interval_tree_.Remove(interval.start, interval.end, chunk); } for (const auto& interval : pending_async_copies_) { @@ -731,69 +1017,104 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks( CommitChunk(buffer_interval, chunk_candidate); } -bool AlternateMemoryBestFitHeap::RequiredInDefaultMemory(const HloValue* buffer, - int64 time) const { +absl::optional +AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer, + int64 time) const { auto required_assignment_it = required_assignments_.find(buffer); - return required_assignment_it != required_assignments_.end() && - absl::c_any_of( - required_assignment_it->second, - [&](const RequiredMemoryAssignment& required_assignment) { - return required_assignment.memory_space == - MemorySpace::kDefault && - required_assignment.time == time; - }); + absl::optional required_assignment_at_time; + if (required_assignment_it != required_assignments_.end()) { + for (const RequiredMemoryAssignment& required_assignment : + required_assignment_it->second) { + if (required_assignment.time == time) { + // Sanity check that there is only one required at time. + CHECK(!required_assignment_at_time); + required_assignment_at_time = required_assignment; + } + } + } + return required_assignment_at_time; } bool AlternateMemoryBestFitHeap::FindAllocation( const AllocationRequest& request) { + auto allocation_sequence = request.allocation_value->allocation_sequence(); // start_time == end_time is a special case where the value is consumed // multiple times by the same instruction. We can just find the previous // allocation and use that allocation. if (request.start_time == request.end_time) { MemorySpaceAssignment::Allocation* allocation = - GetLiveAllocationAt(*request.allocations, request.end_time); + GetLiveAllocationAt(*allocation_sequence, request.end_time); CHECK_NE(allocation, nullptr); allocation->AddUse(request.use); return true; } - const HloPosition& defining_position = request.buffer->defining_position(); - VLOG(2) << "Finding allocation for " << request.buffer->ToShortString() - << " (" << request.start_time << ", " << request.end_time + const HloPosition& defining_position = + request.allocation_value->defining_position(); + VLOG(2) << "Finding allocation for " + << request.allocation_value->ToShortString() << " (" + << request.start_time << ", " << request.end_time << ") latest prefetch = " << request.latest_prefetch_time - << " last use = " << request.use_times->back() + << " last use = " << request.allocation_value->use_times().back() << " use = " << request.use.ToString() << ". Size = " << request.size << ", def pos = " << defining_position.ToString(); CHECK_LE(request.start_time, request.end_time); // There could be a requirement to pin this buffer to default memory either // because it is a parameter or an output. If the buffer is a parameter, then - // we're allowed to prefetch. If the use expects the ouput to be in default + // we're allowed to prefetch. If the use expects the output to be in default // memory, we cannot prefetch it because if we did, it would be in alternate // memory instead. 
- bool in_default_mem_at_start = - RequiredInDefaultMemory(request.buffer, request.start_time); - bool in_default_mem_at_end = - RequiredInDefaultMemory(request.buffer, request.end_time); + auto required_assignment_at_start = RequiredMemoryAssignmentAt( + request.allocation_value->value(), request.start_time); + absl::optional required_memory_space_at_start; + if (required_assignment_at_start) { + required_memory_space_at_start = required_assignment_at_start->memory_space; + } + auto required_assignment_at_end = RequiredMemoryAssignmentAt( + request.allocation_value->value(), request.end_time); + absl::optional required_memory_space_at_end; + if (required_assignment_at_end) { + required_memory_space_at_end = required_assignment_at_end->memory_space; + } + + if (required_assignment_at_start) { + if (!allocation_sequence->empty() && + required_assignment_at_start->memory_space == MemorySpace::kAlternate) { + const auto& prev_allocation = allocation_sequence->back(); + CHECK(prev_allocation->memory_space() == + required_assignment_at_start->memory_space); + CHECK_EQ(prev_allocation->chunk().offset, + required_assignment_at_start->chunk->offset); + prev_allocation->Extend(request.start_time); + } else { + allocation_sequence->push_back( + absl::make_unique( + defining_position, required_assignment_at_start->memory_space, + required_assignment_at_start->chunk, request.start_time, + request.start_time)); + } + } // First try keeping the allocation entirely in the alternate memory. - if (!in_default_mem_at_start && !in_default_mem_at_end && + if (required_memory_space_at_start != MemorySpace::kDefault && + required_memory_space_at_end != MemorySpace::kDefault && AllocateInAlternateMemoryNoCopy(request)) { return true; } - auto prev_allocation_it = request.allocations->rbegin(); + auto prev_allocation_it = allocation_sequence->rbegin(); // Find a previous allocation that is in the default memory space (not // necessarily the very last allocation). 
auto prev_allocation_in_default_mem_it = std::find_if( - request.allocations->rbegin(), request.allocations->rend(), + allocation_sequence->rbegin(), allocation_sequence->rend(), [&](const auto& allocation) { return allocation->memory_space() == MemorySpace::kDefault && allocation->defining_position() == defining_position; }); - if (prev_allocation_in_default_mem_it == request.allocations->rend() && - prev_allocation_it != request.allocations->rend() && + if (prev_allocation_in_default_mem_it == allocation_sequence->rend() && + prev_allocation_it != allocation_sequence->rend() && (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate && (*prev_allocation_it)->defining_position() == defining_position) { // If there was an allocation for this HloValue that was in the alternate @@ -801,21 +1122,21 @@ bool AlternateMemoryBestFitHeap::FindAllocation( if (!Evict(request)) { return false; } - prev_allocation_in_default_mem_it = request.allocations->rbegin(); - } else if (prev_allocation_in_default_mem_it == request.allocations->rend()) { - request.allocations->push_back( + prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); + } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) { + allocation_sequence->push_back( absl::make_unique( - defining_position, MemorySpace::kDefault, kDummyChunk, + defining_position, MemorySpace::kDefault, /*chunk=*/absl::nullopt, request.start_time, request.end_time)); - prev_allocation_in_default_mem_it = request.allocations->rbegin(); + prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); } - CHECK(prev_allocation_in_default_mem_it != request.allocations->rend()); + CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend()); CHECK((*prev_allocation_in_default_mem_it)->memory_space() == MemorySpace::kDefault); // If the buffer must be in default memory at the end_time, don't prefetch. - if (in_default_mem_at_end) { + if (required_memory_space_at_end == MemorySpace::kDefault) { VLOG(4) << "Not trying to prefetch because use requires buffer in default mem."; (*prev_allocation_in_default_mem_it)->Extend(request.end_time); @@ -828,6 +1149,12 @@ bool AlternateMemoryBestFitHeap::FindAllocation( return true; } + // If the end assignment was required to be in alternate memory but that + // wasn't possible, then this allocation is invalid. + if (required_memory_space_at_end == MemorySpace::kAlternate) { + return false; + } + // If a copy wasn't inserted, then add this use to the latest allocation in // default memory. 
(*prev_allocation_in_default_mem_it)->Extend(request.end_time); @@ -837,8 +1164,8 @@ bool AlternateMemoryBestFitHeap::FindAllocation( void AlternateMemoryBestFitHeap::AddAsyncCopy( const MemorySpaceAssignment::Allocation& prev_allocation, - MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time, - int64 copy_done_schedule_before_time, + MemorySpace memory_space, absl::optional chunk, int64 start_time, + int64 end_time, int64 copy_done_schedule_before_time, MemorySpaceAssignment::AllocationSequence* allocations) { VLOG(3) << "Copy to " << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault @@ -886,15 +1213,16 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( const AllocationRequest& request) { MemorySpaceAssignment::Allocation* prev_allocation = nullptr; bool can_eliminate_copy = false; - if (request.allocations->empty()) { + if (request.allocation_value->allocation_sequence()->empty()) { // There hasn't been any allocations for this interval so far. We can // eliminate copy if the value can be placed in the alternate memory. - can_eliminate_copy = - options_.is_allowed_in_alternate_mem_fn(*request.buffer); + can_eliminate_copy = options_.is_allowed_in_alternate_mem_fn( + *request.allocation_value->value()); } else { // If there has been a previous allocation, we can eliminate the copy if the // previous allocation was also in the alternate memory. - prev_allocation = request.allocations->back().get(); + prev_allocation = + request.allocation_value->allocation_sequence()->back().get(); can_eliminate_copy = (prev_allocation->memory_space() == MemorySpace::kAlternate); } @@ -903,15 +1231,16 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( return false; } - const HloPosition& defining_position = request.buffer->defining_position(); + const HloPosition& defining_position = + request.allocation_value->defining_position(); if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( - defining_position.instruction->shape(), request.start_time + 1, + defining_position.shape(), request.start_time + 1, request.end_time)) { return false; } BufferInterval alternate_mem_interval; - alternate_mem_interval.buffer = request.buffer; + alternate_mem_interval.buffer = request.allocation_value->value(); alternate_mem_interval.size = request.size; alternate_mem_interval.end = request.end_time; alternate_mem_interval.start = request.start_time; @@ -925,6 +1254,15 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( alternate_mem_interval.start = prev_allocation->end_time() + 1; } + if (request.preferred_offset) { + // Sanity check that if there is a preferred offset provided in the request, + // it matches with the previous allocation. + CHECK(!preferred_offset || request.preferred_offset == preferred_offset) + << "preferred_offset = " << *preferred_offset + << ", request.preferred_offset = " << *request.preferred_offset; + preferred_offset = request.preferred_offset; + } + VLOG(4) << "We can eliminate copy to alternate memory. Preferred offset = " << (preferred_offset ? *preferred_offset : -1); // In case there are additional uses after this use, we rely on the last use @@ -948,9 +1286,8 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( // for the entire live range. This can result in unnecessary copies. By using // the last use time, we try to find an allocation that is available for the // entire Producer to Use2 range. 
- absl::optional chunk_candidate = - FindBestChunkCandidate(request.end_time, *request.use_times, - preferred_offset, &alternate_mem_interval); + absl::optional chunk_candidate = FindBestChunkCandidate( + request, preferred_offset, &alternate_mem_interval); // Check if the new heap size fits within limits. Also ensure if a // preferred offset was provided, that offset was used. if (chunk_candidate) { @@ -960,7 +1297,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( << ", heap_size = " << chunk_candidate->heap_size << ", prefetch picker = " << options_.prefetch_interval_picker->ToNoCopyDebugString( - defining_position.instruction->shape(), request.start_time, + defining_position.shape(), request.start_time, request.end_time); AddToPendingChunks(alternate_mem_interval, *chunk_candidate); @@ -971,38 +1308,40 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( prev_allocation->defining_position() == defining_position)) { prev_allocation->Extend(request.end_time); } else { - request.allocations->push_back( + request.allocation_value->allocation_sequence()->push_back( absl::make_unique( defining_position, MemorySpace::kAlternate, chunk_candidate->chunk, request.start_time, request.end_time)); } - request.allocations->back()->AddUse(request.use); + request.allocation_value->allocation_sequence()->back()->AddUse( + request.use); return true; } return false; } bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { - CHECK_GT(request.allocations->size(), 0); + CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0); MemorySpaceAssignment::Allocation* prev_allocation = - request.allocations->back().get(); + request.allocation_value->allocation_sequence()->back().get(); int64 eviction_start_time = prev_allocation->start_time(); int64 eviction_end_time = prev_allocation->end_time(); CHECK(eviction_start_time <= eviction_end_time); int64 preferred_eviction_end_time = std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime( - request.buffer->instruction()->shape(), eviction_start_time, - request.end_time), + request.allocation_value->defining_position().shape(), + eviction_start_time, request.end_time), eviction_end_time); BufferInterval eviction_mem_interval; - eviction_mem_interval.buffer = request.buffer; + eviction_mem_interval.buffer = request.allocation_value->value(); eviction_mem_interval.size = request.size; // Try to reserve a buffer from the end of the previous allocation to the // preferred eviction end time. eviction_mem_interval.start = eviction_end_time + 1; - eviction_mem_interval.end = preferred_eviction_end_time; + eviction_mem_interval.end = + std::min(preferred_eviction_end_time, global_max_time_); int64 preferred_offset = prev_allocation->chunk().offset; VLOG(4) << "Eviction (" << eviction_start_time << ", " << eviction_end_time << ") preferred end time = " << eviction_mem_interval.end; @@ -1029,9 +1368,10 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { // See if this interval would violate the asynchronous copy limit. 
if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) { prev_allocation->Extend(eviction_end_time); - AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk, - eviction_start_time, prev_allocation->end_time(), - eviction_end_time, request.allocations); + AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, + /*chunk=*/absl::nullopt, eviction_start_time, + prev_allocation->end_time(), eviction_end_time, + request.allocation_value->allocation_sequence()); } else { if (eviction_violates_outstanding_copies) { VLOG(3) << "This violates the maximum async copies."; @@ -1046,8 +1386,9 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { VLOG(3) << "Try evicting (" << time << ", " << time + 1 << ")"; if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1)) { VLOG(3) << "Eviction successful."; - AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk, time, - time + 1, time + 1, request.allocations); + AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, + /*chunk=*/absl::nullopt, time, time + 1, time + 1, + request.allocation_value->allocation_sequence()); eviction_scheduled = true; break; } @@ -1094,7 +1435,7 @@ bool AlternateMemoryBestFitHeap::Prefetch( // Create an alternate memory interval that starts at the earliest // possible position, given by max_prefetch_interval. BufferInterval alternate_mem_interval; - alternate_mem_interval.buffer = request.buffer; + alternate_mem_interval.buffer = request.allocation_value->value(); alternate_mem_interval.size = request.size; while (!options_.prefetch_interval_picker->Done()) { alternate_mem_interval.start = options_.prefetch_interval_picker->Next(); @@ -1114,8 +1455,7 @@ bool AlternateMemoryBestFitHeap::Prefetch( } auto chunk_candidate = FindBestChunkCandidate( - request.end_time, *request.use_times, - /*preferred_offset=*/absl::nullopt, &alternate_mem_interval); + request, request.preferred_offset, &alternate_mem_interval); // Check if we could find a suitable chunk. if (chunk_candidate) { VLOG(3) << "Move the buffer to alternate memory at " @@ -1130,9 +1470,10 @@ bool AlternateMemoryBestFitHeap::Prefetch( AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate, chunk_candidate->chunk, alternate_mem_interval.start, request.end_time, request.latest_prefetch_time, - request.allocations); + request.allocation_value->allocation_sequence()); - request.allocations->back()->AddUse(request.use); + request.allocation_value->allocation_sequence()->back()->AddUse( + request.use); return true; } } @@ -1141,14 +1482,16 @@ bool AlternateMemoryBestFitHeap::Prefetch( absl::optional AlternateMemoryBestFitHeap::FindBestChunkCandidate( - int64 end_time, const std::vector& use_times, - absl::optional preferred_offset, + const AllocationRequest& request, absl::optional preferred_offset, BufferInterval* alternate_mem_interval) const { + int64 end_time = request.end_time; if (!preferred_offset) { // Find a chunk that's as long living as possible iterating in reverse over // the use times. 
- for (auto use_time = use_times.rbegin(); - use_time != use_times.rend() && *use_time >= end_time; ++use_time) { + for (auto use_time = request.allocation_value->use_times().rbegin(); + use_time != request.allocation_value->use_times().rend() && + *use_time >= end_time; + ++use_time) { alternate_mem_interval->end = *use_time; ChunkCandidate chunk_candidate = FindChunkCandidate(*alternate_mem_interval); @@ -1360,8 +1703,8 @@ MemorySpaceAssignment::Run(HloModule* module, MemorySpaceAssignment memory_space_assignment(module, options, hlo_live_range); auto algorithm = absl::make_unique( - &memory_space_assignment.allocation_sequence_list_, options, - alias_analysis, hlo_live_range); + &memory_space_assignment.allocations_, options, alias_analysis, + hlo_live_range); if (options.enable_cross_program_prefetch) { absl::optional prefetch_candiate = @@ -1426,6 +1769,27 @@ void MemorySpaceAssignment::Allocation::AddUse(HloUse use) { Status MemorySpaceAssignment::Allocation::Process( MemorySpaceAssignment* memory_space_assignment) { + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); + for (const HloUse& use : uses_) { + Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); + HloInstruction* replacement_instruction = producing_instruction; + if (operand_shape.IsTuple()) { + TF_ASSIGN_OR_RETURN( + replacement_instruction, + ReplaceTupleWith(producing_instruction, + use.instruction->mutable_operand(use.operand_number), + use.operand_index)); + } else if (operand_shape != producing_instruction->shape()) { + VLOG(4) << "Old shape = " << operand_shape.ToString() + << ", new shape = " << producing_instruction->shape().ToString() + << "; inserting a bitcast."; + replacement_instruction = computation->AddInstruction( + HloInstruction::CreateBitcast(operand_shape, producing_instruction)); + } + TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( + use.operand_number, replacement_instruction)); + } return Status::OK(); } @@ -1472,30 +1836,34 @@ StatusOr MemorySpaceAssignment::Allocation::ReplaceTupleWith( return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args)); } -Status MemorySpaceAssignment::CopyAllocation::Process( - MemorySpaceAssignment* memory_space_assignment) { - // Copy allocations need to insert asynchronous copy nodes. +HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() { HloInstruction* producing_instruction = defining_position().instruction; CHECK_NE(producing_instruction, nullptr); Shape shape = defining_position().shape(); - CHECK(shape.IsArray()) << "CopyAllocation shape is not an array. Shape = " + CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = " << shape.ToString() << " position = " << defining_position().shape(); HloComputation* computation = producing_instruction->parent(); - // If the instruction we're copying from is a tuple, we (recursively) create + // If the instruction we're processing is a tuple, we (recursively) create // kGetTupleElement instructions and copy that value. Asynchronous copies only // support array types. 
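// (Editorial illustration: for a defining position {w, /*index=*/{1, 0}} the
// loop below first emits gte1 = get-tuple-element(w), index=1 and then
// gte0 = get-tuple-element(gte1), index=0, returning gte0, whose shape is the
// array at that shape index.)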
- if (!producing_instruction->shape().IsArray()) { - producing_instruction = defining_position().instruction; - for (int64 index : defining_position().index) { - producing_instruction = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - producing_instruction->shape().tuple_shapes(index), - producing_instruction, index)); - } + for (int64 index : defining_position().index) { + producing_instruction = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + producing_instruction->shape().tuple_shapes(index), + producing_instruction, index)); } + return producing_instruction; +} + +Status MemorySpaceAssignment::CopyAllocation::Process( + MemorySpaceAssignment* memory_space_assignment) { + // Copy allocations need to insert asynchronous copy nodes. + Shape shape = defining_position().shape(); + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), HloOpcode::kCopyStart, producing_instruction)); @@ -1536,17 +1904,38 @@ Status MemorySpaceAssignment::CopyAllocation::Process( Status MemorySpaceAssignment::Process() { // Insert CopyStart/CopyDone pairs. int64 alternate_memory_size = 0; - for (auto& value_and_sequence : allocation_sequence_list_) { - for (auto& allocation : value_and_sequence.sequence) { - TF_RETURN_IF_ERROR(allocation->Process(this)); - // Add the offset and size of the allocation in the alternate memory to - // the output map. - if (allocation->memory_space() == MemorySpace::kAlternate) { - preset_assignments_->add_chunk(allocation->defining_position(), + std::vector> position_and_chunks; + for (auto& allocation : allocations_) { + TF_RETURN_IF_ERROR(allocation->Process(this)); + // Add the offset and size of the allocation in the alternate memory to + // the output map. 
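// (Editorial note: with aliasing across while/conditional boundaries, several
// defining positions can now map to the same HloBuffer. The bookkeeping below
// therefore exports only one chunk per buffer into preset_assignments_ and
// CHECKs that every aliased position received the same alternate-memory
// offset.)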
+ if (allocation->memory_space() == MemorySpace::kAlternate) { + position_and_chunks.emplace_back(allocation->defining_position(), allocation->chunk()); - alternate_memory_size = - std::max(alternate_memory_size, allocation->chunk().chunk_end()); - } + alternate_memory_size = + std::max(alternate_memory_size, allocation->chunk().chunk_end()); + } + } + + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_)); + absl::flat_hash_map seen_buffer_offsets; + VLOG(3) << "Exported alternate memory allocations:"; + for (const auto& position_and_chunk : position_and_chunks) { + const HloPosition& defining_position = position_and_chunk.first; + const Chunk& chunk = position_and_chunk.second; + const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt( + defining_position.instruction, defining_position.index); + auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id()); + if (seen_buffer_offset_it != seen_buffer_offsets.end()) { + CHECK_EQ(chunk.offset, seen_buffer_offset_it->second) + << "Mismatch in offset for positions that map to the same value: " + << buffer.ToString() << ", pos: " << defining_position.ToString(); + } else { + VLOG(3) << " [" << chunk.offset << ", " << chunk.size + << "] : " << defining_position.ToString() << " (" + << buffer.ToString() << ")"; + preset_assignments_->add_chunk(defining_position, chunk); + seen_buffer_offsets[buffer.id()] = chunk.offset; } } @@ -1556,20 +1945,12 @@ Status MemorySpaceAssignment::Process() { ->size = alternate_memory_size; } - if (VLOG_IS_ON(3)) { - VLOG(3) << "Exported alternate memory allocations:"; - for (auto& pair : preset_assignments_->chunks()) { - VLOG(3) << " [" << pair.second.offset << ", " << pair.second.size - << "] : " << pair.first.ToString(); - } - VLOG(3) << "Exported alternate memory sizes:"; - for (auto& pair : preset_assignments_->assignment_informations()) { - VLOG(3) << " space: " << pair.first << ", size: " << pair.second.size; - } + VLOG(3) << "Exported alternate memory sizes:"; + for (auto& pair : preset_assignments_->assignment_informations()) { + VLOG(3) << " space: " << pair.first << ", size: " << pair.second.size; } // Color the pending positions and all of their aliased buffers. - TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_)); for (const auto& defining_position_and_chunk : preset_assignments_->chunks()) { const HloPosition& defining_position = defining_position_and_chunk.first; @@ -1666,6 +2047,38 @@ Status MemorySpaceAssignment::SimplifyGraph() { instruction->ReplaceAllUsesWith(forwarded_instruction)); computation_modified = true; } + } else if (instruction->opcode() == HloOpcode::kTuple) { + // Replace Tuple(GetTupleElement(x), ..., GetTupleElement(x)) pattern + // with x. 
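// For example (editorial illustration):
//   x  = (f32[2], f32[3]) tuple(...)
//   g0 = f32[2] get-tuple-element(x), index=0
//   g1 = f32[3] get-tuple-element(x), index=1
//   t  = (f32[2], f32[3]) tuple(g0, g1)
// Every use of t can use x directly, provided the get-tuple-element indices
// appear in order and cover all of x's elements, which is exactly what the
// checks below verify.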
+ bool can_replace = + instruction->operand_count() > 0 && + instruction->operand(0)->opcode() == + HloOpcode::kGetTupleElement && + instruction->operand(0) + ->operand(0) + ->shape() + .tuple_shapes_size() == instruction->operand_count(); + for (int operand_number = 0; + operand_number < instruction->operand_count(); + ++operand_number) { + const HloInstruction* operand = + instruction->operand(operand_number); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->tuple_index() != operand_number || + operand->operand(0) != instruction->operand(0)->operand(0)) { + can_replace = false; + break; + } + } + if (can_replace) { + HloInstruction* forwarded_instruction = + instruction->mutable_operand(0)->mutable_operand(0); + VLOG(4) << "Replacing uses of " << instruction->ToString() + << " with " << forwarded_instruction->ToString(); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(forwarded_instruction)); + computation_modified = true; + } } } } @@ -1700,13 +2113,11 @@ void MemorySpaceAssignment::ScheduleAsynchronousCopies() { for (MemorySpace memory_space : {MemorySpace::kDefault, MemorySpace::kAlternate}) { std::vector copy_allocations; - for (auto& value_and_sequence : allocation_sequence_list_) { - for (auto& allocation : value_and_sequence.sequence) { - if (allocation->is_copy_allocation()) { - auto copy_allocation = static_cast(allocation.get()); - if (copy_allocation->memory_space() == memory_space) { - copy_allocations.push_back(copy_allocation); - } + for (auto& allocation : allocations_) { + if (allocation->is_copy_allocation()) { + auto copy_allocation = static_cast(allocation.get()); + if (copy_allocation->memory_space() == memory_space) { + copy_allocations.push_back(copy_allocation); } } } @@ -1838,50 +2249,43 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { const Chunk& chunk = position_and_chunk.second; const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(position.instruction, position.index); - if (seen_buffers.contains(buffer.id())) { - continue; - } + CHECK(!seen_buffers.contains(buffer.id())) + << "Multiple preset assignments for the same buffer: " + << buffer.ToString() << ", pos: " << position.ToString() + << ", off: " << chunk.offset << ", size: " << chunk.size; seen_buffers.insert(buffer.id()); - int64 start_time = INT64_MAX; - int64 end_time = -1; for (const HloValue* value : buffer.values()) { const HloLiveRange::TimeBound& time_bound = hlo_live_range->buffer_live_ranges().at(value); - VLOG(3) << " value: " << value->ToShortString() << " (" - << time_bound.start << ", " << time_bound.end << ")"; - start_time = std::min(start_time, time_bound.start); - end_time = std::max(end_time, time_bound.end); events[std::make_pair(time_bound.start, value->id())] = std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC); events[std::make_pair(time_bound.end, value->id())] = std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE); - } - CHECK_GE(start_time, 0); - CHECK_GT(end_time, 0); - // Get the chunks overlapping in time and search if they overlap in space as - // well. - // TODO(berkin): For now checking against end_time - 1 (exclusive), but we - // really should check against end_time (inclusive) for cases where the - // operand can't share buffer with user (see - // HloDataflowAnalysis::CanShareOperandBufferWithUser). - if (options_.verify || VLOG_IS_ON(1)) { - // Verify only if the option is set or if vlog is on. 
+ + VLOG(3) << " buffer: " << buffer.ToString() + << " value: " << value->ToShortString() << ": (" + << time_bound.start << ", " << time_bound.end + << ") off: " << chunk.offset << ", size: " << chunk.size; + // Get the chunks overlapping in time and search if they overlap in space + // as well. + // TODO(berkin): For now checking against end_time - 1 (exclusive), but we + // really should check against end_time (inclusive) for cases where the + // operand can't share buffer with user (see + // HloDataflowAnalysis::CanShareOperandBufferWithUser). for (const Chunk& overlapping_chunk : - interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) { + interval_tree.ChunksOverlappingInTime(time_bound.start, + time_bound.end - 1)) { if (chunk.OverlapsWith(overlapping_chunk)) { return InternalError( ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk" " off: %d size: %d"), - buffer.ToString(), start_time, end_time, chunk.offset, chunk.size, - overlapping_chunk.offset, overlapping_chunk.size); + buffer.ToString(), time_bound.start, time_bound.end, chunk.offset, + chunk.size, overlapping_chunk.offset, overlapping_chunk.size); } } + interval_tree.Add(time_bound.start, time_bound.end - 1, chunk); } - interval_tree.Add(start_time, end_time - 1, chunk); - VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", " - << end_time << ") off: " << position_and_chunk.second.offset - << ", size: " << position_and_chunk.second.size; } HeapSimulatorTrace* heap_trace = diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 4a24f60b6a9..5a897d6cefa 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -363,7 +363,7 @@ class MemorySpaceAssignment { class Allocation { public: Allocation(HloPosition defining_position, MemorySpace memory_space, - Chunk chunk, int64 start_time, int64 end_time) + absl::optional chunk, int64 start_time, int64 end_time) : defining_position_(defining_position), memory_space_(memory_space), chunk_(chunk), @@ -393,7 +393,7 @@ class MemorySpaceAssignment { const std::vector& uses() const { return uses_; } MemorySpace memory_space() const { return memory_space_; } - Chunk chunk() const { return chunk_; } + Chunk chunk() const { return *chunk_; } void set_start_time(int64 start_time) { start_time_ = start_time; } int64 start_time() const { return start_time_; } int64 end_time() const { return end_time_; } @@ -405,10 +405,14 @@ class MemorySpaceAssignment { HloInstruction* tuple, ShapeIndex shape_index); + // Recursively create kGetTupleElement instructions if the defining position + // shape is not an array. Returns the new instruction that has array shape. 
+
     HloPosition defining_position_;
     std::vector<HloUse> uses_;
     MemorySpace memory_space_;
-    Chunk chunk_;
+    absl::optional<Chunk> chunk_;
     int64 start_time_;
     int64 end_time_;
   };
@@ -417,8 +421,8 @@ class MemorySpaceAssignment {
   class CopyAllocation : public Allocation {
    public:
     CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
-                   Chunk chunk, int64 start_time, int64 end_time,
-                   int64 copy_done_schedule_before_time)
+                   absl::optional<Chunk> chunk, int64 start_time,
+                   int64 end_time, int64 copy_done_schedule_before_time)
         : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
                      start_time, end_time),
           prev_allocation_(prev_allocation),
@@ -476,11 +480,105 @@ class MemorySpaceAssignment {
   };

   using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
-  struct ValueAndAllocationSequence {
-    const HloValue* value;
-    AllocationSequence sequence;
+  // AllocationValue is used to break up HloValues for each non-trivial position
+  // (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An
+  // HloValue may include positions and uses that alias with each other across
+  // multiple computations. We use this class to break these HloValues such that
+  // every AllocationValue has one defining position (that may alias with other
+  // AllocationValues). The uses field of the AllocationValue contains only the
+  // direct uses of the AllocationValue's defining position.
+  //
+  // For example, consider the following HLO snippet:
+  //
+  // Body {
+  //   body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
+  //   get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
+  //   index=0
+  //   ...
+  //   ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
+  // }
+  //
+  // Cond {
+  //   cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
+  //   ...
+  // }
+  //
+  // add.4 = f32[4,3]{1,0} add(...)
+  // tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
+  // while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
+  // get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
+  // add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
+  //
+  // This contains an HloValue that looks like the following:
+  // positions:
+  //  add.4
+  //  body_param {0}
+  //  get-tuple-element.3
+  //  tuple {0}
+  //  cond_param {0}
+  //  tuple.1 {0}
+  //  while {0}
+  //  get-tuple-element.5
+  // uses:
+  //  add.1, operand 0
+  //  tuple, operand 0
+  //  while, operand 0 {0}
+  //  add.5, operand 0
+  //
+  // We break this HloValue up into the following AllocationValues for each
+  // non-trivial position:
+  // AllocationValue1: computation = Entry
+  //  position:
+  //   add.4
+  //  uses:
+  //   while, operand 0 {0}
+  // AllocationValue2: computation = Cond
+  //  position:
+  //   cond_param {0}
+  //  uses:
+  // AllocationValue3: computation = Body
+  //  position:
+  //   body_param {0}
+  //  uses:
+  //   add.1, operand 0
+  //   tuple, operand 0
+  // AllocationValue4: computation = Entry
+  //  position:
+  //   while {0}
+  //  uses:
+  //   add.5, operand 0
+  class AllocationValue {
+   public:
+    AllocationValue(const HloValue* value, const HloPosition& position)
+        : value_(value), defining_position_(position) {}
+
+    const HloPosition& defining_position() const { return defining_position_; }
+    const HloInstruction* defining_instruction() const {
+      return defining_position().instruction;
+    }
+    const std::vector<HloUse>& uses() const { return uses_; }
+    const std::vector<int64>& use_times() const { return use_times_; }
+    const HloValue* value() const { return value_; }
+    const HloComputation* computation() const {
+      return defining_instruction()->parent();
+    }
+    AllocationSequence* allocation_sequence() { return &allocation_sequence_; }
+
+    void AddUse(const HloUse& use, int64 use_time) {
+      uses_.push_back(use);
+      use_times_.push_back(use_time);
+    }
+
+    std::string ToString() const;
+    std::string ToShortString() const;
+
+   private:
+    const HloValue* value_;
+    HloPosition defining_position_;
+    std::vector<HloUse> uses_;
+    std::vector<int64> use_times_;
+    AllocationSequence allocation_sequence_;
   };
-  using AllocationSequenceList = std::vector<ValueAndAllocationSequence>;
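// Illustrative sketch (not part of this patch): roughly how AllocationValue
// objects could be created for the non-trivial positions of an HloValue,
// mirroring the class above. The trivial-opcode filter and the `value`
// pointer are assumptions for illustration only.
std::vector<MemorySpaceAssignment::AllocationValue> allocation_values;
for (const HloPosition& position : value->positions()) {
  HloOpcode opcode = position.instruction->opcode();
  bool is_trivial = opcode == HloOpcode::kGetTupleElement ||
                    opcode == HloOpcode::kTuple ||
                    opcode == HloOpcode::kBitcast;
  if (!is_trivial) {
    // One AllocationValue per non-trivial defining position.
    allocation_values.emplace_back(value, position);
  }
}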

   // Runs the MemorySpaceAssignment pass.
   static StatusOr<std::unique_ptr<PresetAssignments>> Run(
@@ -545,7 +643,7 @@ class MemorySpaceAssignment {
   Options options_;
   std::vector<HloInstruction*> flattened_instructions_;
   absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
-  AllocationSequenceList allocation_sequence_list_;
+  AllocationSequence allocations_;
   std::unique_ptr<PresetAssignments> preset_assignments_;

   // These maps hold vectors of new instructions that need to be scheduled after
@@ -562,6 +660,7 @@
 struct RequiredMemoryAssignment {
   MemorySpaceAssignment::MemorySpace memory_space;
   int64 time;
+  absl::optional<HeapSimulator::Chunk> chunk;
 };

 // A struct representing an asynchronous copy with its logical start and end
@@ -614,14 +713,15 @@ class AsynchronousCopyOrdering {
 class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
  public:
   using MemorySpace = MemorySpaceAssignment::MemorySpace;
+  using AllocationValue = MemorySpaceAssignment::AllocationValue;

   AlternateMemoryBestFitHeap(
-      MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list,
+      MemorySpaceAssignment::AllocationSequence* allocations,
       const MemorySpaceAssignment::Options& options,
       const HloAliasAnalysis& alias_analysis,
       const HloLiveRange& hlo_live_range)
       : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
-        allocation_sequence_list_(allocation_sequence_list),
+        allocations_(allocations),
         options_(options),
         alias_analysis_(alias_analysis),
         hlo_live_range_(hlo_live_range) {
@@ -632,7 +732,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   }

   // Allocates a buffer in preferred memory with whole program lifetime and
-  // enables prefetching prefech_candidate from default memory across program
+  // enables prefetching prefetch_candidate from default memory across program
   // boundaries.
   void AllocateCrossProgramPrefetchBuffer(
       HloModule* module, absl::optional<BufferInterval> prefetch_candidate);
@@ -660,12 +760,11 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   struct AllocationRequest {
     int64 start_time;
     int64 end_time;
-    const std::vector<int64>* use_times;
     int64 latest_prefetch_time;
     int64 size;
+    absl::optional<int64> preferred_offset;
     HloUse use;
-    const HloValue* buffer;
-    MemorySpaceAssignment::AllocationSequence* allocations;
+    MemorySpaceAssignment::AllocationValue* allocation_value;
   };

   // Given an allocation sequence, returns the live allocation at time with a
@@ -674,15 +773,22 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
       const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);

-  // Returns true if a buffer is required to be in default memory at a
-  // particular time. A buffer may be required to be in default memory because
-  // it is a parameter in default memory or an ouput in default memory.
-  bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;
+  // Returns the required assignment at a particular time, if available.
+  absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
+      const HloValue* buffer, int64 time) const;

   // Returns true if this buffer is allowed to be placed in the alternate
   // memory.
   bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;

+  // Returns true if the use is allowed in the alternate memory.
+  bool IsUseAllowedInAlternateMemory(const HloUse& use) const;
+
+  // Given an HloValue, creates AllocationValue objects and corresponding
+  // AllocationSequences and appends them into allocation_values.
+  void CreateAllocationValues(const HloValue* value,
+                              std::vector<AllocationValue>* allocation_values);
+
   // Finds an allocation for the given interval.
   //
   // It performs three things in the following order:
@@ -715,10 +821,21 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // availability if no preferred offset is given, or at the preferred_offset if
   // it is given.
   absl::optional<ChunkCandidate> FindBestChunkCandidate(
-      int64 end_time, const std::vector<int64>& use_times,
-      absl::optional<int64> preferred_offset,
+      const AllocationRequest& request, absl::optional<int64> preferred_offset,
       BufferInterval* alternate_mem_interval) const;

+  // At the end of an allocation with a sequential call (while, conditional, and
+  // call), this function adds the necessary aliased assignments within the
+  // called computations.
+  void AddAliasedRequiredAssignmentsForSequentialCall(
+      const HloUse& use,
+      const MemorySpaceAssignment::Allocation* aliased_allocation);
+
+  // Propagates aliased required assignment for a given position.
+  void AddAliasedRequiredAssignment(
+      const HloInstruction* instruction, ShapeIndex index,
+      const MemorySpaceAssignment::Allocation* aliased_allocation);
+
   // Adds input and outputs as required assignments.
   void AddInputAndOutputRequiredAssignments();

@@ -734,7 +851,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   std::vector<const BufferInterval*> GetSortedColocatedIntervals(
       const BufferInterval& interval) const;

-  // Since the allocations are recorded to the AllocationSequenceList, we don't
+  // Since the allocations are recorded to the AllocationSequence, we don't
   // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
   // to avoid unnecessarily adding the chunk to the chunk map.
   void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
@@ -749,8 +866,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {

   // Adds an asynchronous copy to the allocations.
   void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
-                    MemorySpace memory_space, Chunk chunk, int64 start_time,
-                    int64 end_time, int64 copy_done_schedule_before_time,
+                    MemorySpace memory_space, absl::optional<Chunk> chunk,
+                    int64 start_time, int64 end_time,
+                    int64 copy_done_schedule_before_time,
                     MemorySpaceAssignment::AllocationSequence* allocations);

   // This method is used for committing the chunk candidate but adding it to
@@ -768,7 +886,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
     return options_.max_size_in_bytes - reserved_in_bytes_;
   }

-  MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list_;
+  MemorySpaceAssignment::AllocationSequence* allocations_;
   const MemorySpaceAssignment::Options& options_;
   const HloAliasAnalysis& alias_analysis_;
   const HloLiveRange& hlo_live_range_;
@@ -784,6 +902,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
       required_assignments_;

   // Number of bytes reserved in alternate memory space.
int64 reserved_in_bytes_ = 0; + int64 global_max_time_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 85a5e7a87a2..9f92ccfef95 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -103,6 +103,21 @@ class MemorySpaceAssignmentTest : public HloTestBase, return true; }; + // Only check parameters in default memory if the original module didn't + // have the parameters in alternate memory. + bool check_parameters_in_default_memory = true; + for (const HloInstruction* parameter : + module->entry_computation()->parameter_instructions()) { + ShapeUtil::ForEachSubshape( + parameter->shape(), + [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (subshape.has_layout() && + subshape.layout().memory_space() == kAlternateMemorySpace) { + check_parameters_in_default_memory = false; + } + }); + } + MemorySpaceAssignment::Options options; options.alternate_memory_space = kAlternateMemorySpace; options.max_size_in_bytes = 128; @@ -125,6 +140,9 @@ class MemorySpaceAssignmentTest : public HloTestBase, MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis, options) .ValueOrDie(); + if (check_parameters_in_default_memory) { + CheckParametersInDefaultMemory(module); + } CheckPresetAssignments(preset_assignments.get()); return preset_assignments; } @@ -148,6 +166,24 @@ class MemorySpaceAssignmentTest : public HloTestBase, } } + void CheckParametersInDefaultMemory(const HloModule* module) { + // Check that all the entry parameter subshapes are placed in default + // memory. + const HloComputation* entry_computation = module->entry_computation(); + for (const HloInstruction* parameter : + entry_computation->parameter_instructions()) { + ShapeUtil::ForEachSubshape( + parameter->shape(), + [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (subshape.has_layout()) { + EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace) + << "Parameter not in default memory: " + << parameter->ToString(); + } + }); + } + } + std::unique_ptr CreateEvictAndPrefetchModule() { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); @@ -1250,11 +1286,282 @@ TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) { if (instruction->opcode() == HloOpcode::kWhile) { const Shape& while_subshape = ShapeUtil::GetSubshape(instruction->shape(), {0}); - EXPECT_NE(while_subshape.layout().memory_space(), kAlternateMemorySpace); + // We expect shape {0} to either be in default memory for the entire while + // loop or there has to be an eviction within the while loop. 
+ if (while_subshape.layout().memory_space() == kAlternateMemorySpace) { + const HloInstruction* body_param = + instruction->while_body()->parameter_instruction(0); + const HloInstruction* gte = nullptr; + for (const HloInstruction* user : body_param->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 0) { + gte = user; + break; + } + } + EXPECT_NE(gte, nullptr); + const HloInstruction* copy_start = nullptr; + for (const HloInstruction* user : gte->users()) { + if (user->opcode() == HloOpcode::kCopyStart) { + copy_start = user; + break; + } + } + EXPECT_NE(copy_start, nullptr); + const Shape& copy_start_subshape = + ShapeUtil::GetSubshape(copy_start->shape(), {0}); + + EXPECT_NE(copy_start_subshape.layout().memory_space(), + kAlternateMemorySpace); + } } } } +TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoops) { + absl::string_view hlo_string = R"( + HloModule WhileAllocationBug, is_scheduled=true + + %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) { + %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2 + %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0 + %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1 + %constant.1 = f32[] constant(1) + %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1) + %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } }) + %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3) + %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply) + %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2) + %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2) + ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add) + } + + %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] { + %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2 + %constant = f32[] constant(50) + ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT + } + + %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) { + %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2 + %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0 + %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1 + %constant.1 = f32[] constant(1) + %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1) + %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } }) + %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3) + %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply) + %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2) + %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} 
%add.1, f32[4,3]{1,0} %multiply2) + ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add) + } + + %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] { + %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2 + %constant = f32[] constant(50) + ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT + } + + ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] { + %param_iter = f32[] parameter(1) + %param_data = f32[4,3]{1,0} parameter(0) + %p2 = f32[4,3]{1,0} parameter(2) + %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2) + %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0) + %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1) + %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2) + %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3) + %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4) + %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5) + %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2) + %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter) + %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody + %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0 + %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4) + %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1 + %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} get-tuple-element.5, f32[] %param_iter) + %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2 + %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0 + ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.6, f32[4,3]{1,0} %add.3) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); +} + +TEST_P(MemorySpaceAssignmentTest, WhileLiveRangeBug) { + // Tests against while live ranges being incorrect and the verifier + // complaining about a conflict. 
+ absl::string_view hlo_string = R"( + HloModule WhileAllocationBug, is_scheduled=true + + %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) { + %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2 + %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0 + %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1 + %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2) + %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10) + %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11) + %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12) + %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13) + %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14) + %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15) + %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16) + %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17) + %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18) + %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19) + %constant.1 = f32[] constant(1) + %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1) + %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } }) + %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20) + %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply) + %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2) + %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2) + ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add) + } + + %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] { + %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2 + %constant = f32[] constant(50) + ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT + } + + ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] { + %param_iter = f32[] parameter(1) + %param_data = f32[4,3]{1,0} parameter(0) + %p2 = f32[4,3]{1,0} parameter(2) + %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2) + %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0) + %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1) + %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2) + %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3) + %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4) + %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5) + %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2) + %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter) + %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody + %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0 + %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1 + %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4) + ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %add.3) + } + )"; + + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); +} + +TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) { + // Tests against a bug when there are consecutive while loops with one buffer + // (the value doesn't change in the buffer), the parameter can be colored in + // the alternate memory space. + absl::string_view hlo_string = R"( + HloModule WhileAllocationBug, is_scheduled=true + + %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) { + %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2 + %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0 + %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1 + %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2) + %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10) + %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11) + %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12) + %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13) + %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14) + %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15) + %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16) + %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17) + %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18) + %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19) + %constant.1 = f32[] constant(1) + %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1) + %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } }) + %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20) + %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply) + %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2) + %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2) + ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add) + } + + %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] { + %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2 + %constant = f32[] constant(50) + ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT + } + + %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) { + %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2 + %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0 + %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1 + %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2) + %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10) + %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11) + %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12) + %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13) + %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14) + %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15) + %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} 
%neg16) + %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17) + %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18) + %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19) + %constant.1 = f32[] constant(1) + %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1) + %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } }) + %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20) + %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply) + %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2) + %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2) + ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add) + } + + %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] { + %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2 + %constant = f32[] constant(50) + ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT + } + + ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] { + %param_iter = f32[] parameter(1) + %param_data = f32[4,3]{1,0} parameter(0) + %p2 = f32[4,3]{1,0} parameter(2) + %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2) + %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0) + %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1) + %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2) + %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3) + %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4) + %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5) + %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2) + %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter) + %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody + %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0 + %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4) + %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} param_data, f32[] %param_iter) + %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2 + %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0 + %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=1 + ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %get-tuple-element.6) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); +} + TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) { // Having control_predecessors on an HLO was preventing us from DCEing an op // that doesn't have any users (tuple.1). 
The scheduler assumes the graph is @@ -2070,12 +2377,6 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, /*max_prefetch_interval=*/25); - int64 memory_space_across_while = kDefaultMemorySpace; - bool allocate_across_sequential_calls = GetParam(); - if (allocate_across_sequential_calls) { - memory_space_across_while = kAlternateMemorySpace; - } - // Index {0} of the while loop argument is not written inside the while loop, // so it can be trivially placed in the alternate memory space. *ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() = @@ -2087,12 +2388,11 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) { LayoutUtil::MakeLayout( /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0, kDefaultMemorySpace); - // Index {2} of the while loop argument is placed in the alternate memory if - // we enable the allocate_across_sequential_calls option. + // Index {2} of the while loop is placed in the default memory. *ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() = LayoutUtil::MakeLayout( /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, - memory_space_across_while); + kDefaultMemorySpace); // Expect the layout for the while loop and its aliased buffers. EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
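// Illustrative sketch (not part of this patch): an equivalent spot-check of a
// single tuple element's memory space, using the same while_op and
// kDefaultMemorySpace constant the test above relies on.
const Shape& index2_subshape =
    ShapeUtil::GetSubshape(while_op->shape(), /*index=*/{2});
EXPECT_EQ(index2_subshape.layout().memory_space(), kDefaultMemorySpace);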