Refactors MemorySpaceAssignment to separate out the creation of the AllocationSequence.
PiperOrigin-RevId: 306322847 Change-Id: Iecccef3441c38c08f5ec42a80af778803a1d898b
This commit is contained in:
parent
867673d63b
commit
4d847bd055
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
@ -1689,6 +1688,32 @@ FindCrossProgramPrefetchCandidate(
|
||||
}
|
||||
return *best_candidate;
|
||||
}
|
||||
|
||||
// Finds an AllocationSequence for placing buffers in alternate memory using the
|
||||
// AlternateMemoryBestFitHeap algorithm.
|
||||
StatusOr<MemorySpaceAssignment::AllocationSequence> FindAllocationSequence(
|
||||
HloModule* module, const HloLiveRange& hlo_live_range,
|
||||
const HloAliasAnalysis& alias_analysis,
|
||||
const MemorySpaceAssignment::Options& options) {
|
||||
MemorySpaceAssignment::AllocationSequence allocations;
|
||||
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
|
||||
&allocations, options, alias_analysis, hlo_live_range);
|
||||
|
||||
if (options.enable_cross_program_prefetch) {
|
||||
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
|
||||
prefetch_candiate = FindCrossProgramPrefetchCandidate(
|
||||
alias_analysis, hlo_live_range, options);
|
||||
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
|
||||
}
|
||||
|
||||
HeapSimulator::Options heap_simulator_options;
|
||||
heap_simulator_options.may_reuse_operand_buffers = false;
|
||||
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
|
||||
module->schedule(), alias_analysis,
|
||||
options.size_fn, heap_simulator_options)
|
||||
.status());
|
||||
return std::move(allocations);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
|
||||
@ -1702,24 +1727,11 @@ MemorySpaceAssignment::Run(HloModule* module,
|
||||
VLOG(4) << "Schedule: " << module->schedule().ToString();
|
||||
MemorySpaceAssignment memory_space_assignment(module, options,
|
||||
hlo_live_range);
|
||||
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
|
||||
&memory_space_assignment.allocations_, options, alias_analysis,
|
||||
hlo_live_range);
|
||||
|
||||
if (options.enable_cross_program_prefetch) {
|
||||
absl::optional<BufferInterval> prefetch_candiate =
|
||||
FindCrossProgramPrefetchCandidate(alias_analysis, hlo_live_range,
|
||||
options);
|
||||
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
|
||||
}
|
||||
|
||||
HeapSimulator::Options heap_simulator_options;
|
||||
heap_simulator_options.may_reuse_operand_buffers = false;
|
||||
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
|
||||
module->schedule(), alias_analysis,
|
||||
options.size_fn, heap_simulator_options)
|
||||
.status());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
AllocationSequence allocations,
|
||||
FindAllocationSequence(module, hlo_live_range, alias_analysis, options));
|
||||
|
||||
memory_space_assignment.SetAllocationSequence(std::move(allocations));
|
||||
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
|
||||
memory_space_assignment.ScheduleAsynchronousCopies();
|
||||
TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
|
||||
|
@ -615,6 +615,14 @@ class MemorySpaceAssignment {
|
||||
}
|
||||
}
|
||||
|
||||
// Sets allocations_. Must be set before Process() is called.
|
||||
// Uses an rvalue reference so that the caller is forced to hand over
|
||||
// ownership of the AllocationSequence, e.g.
|
||||
// SetAllocationSequence(std::move(my_allocation)).
|
||||
void SetAllocationSequence(AllocationSequence&& allocations) {
|
||||
allocations_ = std::move(allocations);
|
||||
}
|
||||
|
||||
// Process calls Process methods of the allocations after the allocations have
|
||||
// been finalized.
|
||||
Status Process();
|
||||
|
Loading…
x
Reference in New Issue
Block a user