Refactors MemorySpaceAssignment to separate out the creation of the AllocationSequence.
PiperOrigin-RevId: 306322847 Change-Id: Iecccef3441c38c08f5ec42a80af778803a1d898b
This commit is contained in:
parent
867673d63b
commit
4d847bd055
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
|
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -1689,6 +1688,32 @@ FindCrossProgramPrefetchCandidate(
|
|||||||
}
|
}
|
||||||
return *best_candidate;
|
return *best_candidate;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Finds an AllocationSequence for placing buffers in alternate memory using the
|
||||||
|
// AlternateMemoryBestFitHeap algorithm.
|
||||||
|
StatusOr<MemorySpaceAssignment::AllocationSequence> FindAllocationSequence(
|
||||||
|
HloModule* module, const HloLiveRange& hlo_live_range,
|
||||||
|
const HloAliasAnalysis& alias_analysis,
|
||||||
|
const MemorySpaceAssignment::Options& options) {
|
||||||
|
MemorySpaceAssignment::AllocationSequence allocations;
|
||||||
|
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
|
||||||
|
&allocations, options, alias_analysis, hlo_live_range);
|
||||||
|
|
||||||
|
if (options.enable_cross_program_prefetch) {
|
||||||
|
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
|
||||||
|
prefetch_candiate = FindCrossProgramPrefetchCandidate(
|
||||||
|
alias_analysis, hlo_live_range, options);
|
||||||
|
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
|
||||||
|
}
|
||||||
|
|
||||||
|
HeapSimulator::Options heap_simulator_options;
|
||||||
|
heap_simulator_options.may_reuse_operand_buffers = false;
|
||||||
|
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
|
||||||
|
module->schedule(), alias_analysis,
|
||||||
|
options.size_fn, heap_simulator_options)
|
||||||
|
.status());
|
||||||
|
return std::move(allocations);
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
|
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
|
||||||
@ -1702,24 +1727,11 @@ MemorySpaceAssignment::Run(HloModule* module,
|
|||||||
VLOG(4) << "Schedule: " << module->schedule().ToString();
|
VLOG(4) << "Schedule: " << module->schedule().ToString();
|
||||||
MemorySpaceAssignment memory_space_assignment(module, options,
|
MemorySpaceAssignment memory_space_assignment(module, options,
|
||||||
hlo_live_range);
|
hlo_live_range);
|
||||||
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
|
TF_ASSIGN_OR_RETURN(
|
||||||
&memory_space_assignment.allocations_, options, alias_analysis,
|
AllocationSequence allocations,
|
||||||
hlo_live_range);
|
FindAllocationSequence(module, hlo_live_range, alias_analysis, options));
|
||||||
|
|
||||||
if (options.enable_cross_program_prefetch) {
|
|
||||||
absl::optional<BufferInterval> prefetch_candiate =
|
|
||||||
FindCrossProgramPrefetchCandidate(alias_analysis, hlo_live_range,
|
|
||||||
options);
|
|
||||||
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
|
|
||||||
}
|
|
||||||
|
|
||||||
HeapSimulator::Options heap_simulator_options;
|
|
||||||
heap_simulator_options.may_reuse_operand_buffers = false;
|
|
||||||
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
|
|
||||||
module->schedule(), alias_analysis,
|
|
||||||
options.size_fn, heap_simulator_options)
|
|
||||||
.status());
|
|
||||||
|
|
||||||
|
memory_space_assignment.SetAllocationSequence(std::move(allocations));
|
||||||
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
|
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
|
||||||
memory_space_assignment.ScheduleAsynchronousCopies();
|
memory_space_assignment.ScheduleAsynchronousCopies();
|
||||||
TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
|
TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
|
||||||
|
@ -615,6 +615,14 @@ class MemorySpaceAssignment {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sets allocations_. Must be set before Process() is called.
|
||||||
|
// Uses an rvalue reference so that the caller is forced to hand over
|
||||||
|
// ownership of the AllocationSequence, e.g.
|
||||||
|
// SetAllocationSequence(std::move(my_allocation)).
|
||||||
|
void SetAllocationSequence(AllocationSequence&& allocations) {
|
||||||
|
allocations_ = std::move(allocations);
|
||||||
|
}
|
||||||
|
|
||||||
// Process calls Process methods of the allocations after the allocations have
|
// Process calls Process methods of the allocations after the allocations have
|
||||||
// been finalized.
|
// been finalized.
|
||||||
Status Process();
|
Status Process();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user