[TF:XLA] Handle some of the buffer aliasing across computations in HeapSimulator.

The new modeling of subcomputations is still not entirely accurate, but probably not worth putting more work into, since TuplePointsToAnalysis will be removed from HeapSimulator soon.

PiperOrigin-RevId: 209646234
This commit is contained in:
Dimitris Vardoulakis 2018-08-21 13:06:23 -07:00 committed by TensorFlower Gardener
parent dc62ab7a7c
commit a7e961ac88
4 changed files with 52 additions and 17 deletions

View File

@ -627,7 +627,7 @@ Status BufferAssignment::ComputeSummaryStats() {
stats_.total_allocation_bytes += allocation.size();
}
// Only compute total fragmentation if all computations are sequential.
// Only compute total fragmentation if all computations have schedules.
SequentialHloOrdering::HloModuleSequence module_sequence;
for (const auto& computation : module_->computations()) {
const std::vector<const HloInstruction*>* sequence =

View File

@ -144,7 +144,7 @@ Status HeapSimulator::RunComputation(
}
} else {
// A GetTupleElement doesn't need to keep all of its operand's buffers
// alive. It only needs the buffers that relate to the element its
// alive. It only needs the buffers that relate to the element it's
// extracting, and the tuple it's extracting from, but not the buffers
// for the other elements.
for (const BufferValue* buffer : points_to.element({})) {
@ -277,13 +277,13 @@ Status HeapSimulator::RunComputation(
*memory_by_computation_);
}
// If the whole module is sequential, we can save memory by running the
// heap-simulation for sub-computations inline. E.g. the buffers for the
// condition and body of a kWhile instruction are only live for the duration
// of the instruction itself.
// If all computations in the module have been scheduled, we can save memory
// by running the heap-simulation for sub-computations inline. E.g. the
// buffers for the condition and body of a kWhile instruction are only live
// for the duration of the instruction itself.
//
// The order that the sub-computations are simulated does not affect
// correctness; since the whole module is sequential, we know that the
// correctness; since the whole module has been scheduled, we know that the
// sub-computations will never be run concurrently.
if (module_sequence_ != nullptr) {
if (instruction->opcode() == HloOpcode::kCall ||
@ -380,9 +380,10 @@ void HeapSimulator::Alloc(const BufferValue* buffer,
allocated_buffers_.insert(buffer);
const int64 size = size_fn_(*buffer);
algorithm_->Alloc(buffer, size);
no_fragmentation_stats_->Alloc(buffer, size);
const HloInstruction* instruction_to_calc_aliasing =
memory_by_computation_ == nullptr ? nullptr : instruction;
algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing);
no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing);
FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
nullptr);
}
@ -520,6 +521,18 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
}
}
void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size,
const HloInstruction* instruction) {
// The output buffer of while/call/conditional is always aliased with the
// output buffer of the root instruction in the body. Don't double count.
if (instruction == nullptr ||
(instruction->opcode() != HloOpcode::kWhile &&
instruction->opcode() != HloOpcode::kCall &&
instruction->opcode() != HloOpcode::kConditional)) {
Alloc(buffer, size);
}
}
void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&

View File

@ -36,6 +36,7 @@ namespace xla {
// Forward declare classes defined below.
class HeapAlgorithm;
class NoFragmentationStatsHeap;
// HeapSimulator assigns buffer offsets by running a simulation of a regular
// memory heap with Alloc and Free calls. It only works for completely
@ -161,7 +162,10 @@ class HeapSimulator {
const HloInstruction* instruction,
const BufferValue* shared_with_canonical);
const std::unique_ptr<HeapAlgorithm> no_fragmentation_stats_;
// Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap,
// in which case we are calculating the same allocs/frees twice in the
// simulation.
const std::unique_ptr<NoFragmentationStatsHeap> no_fragmentation_stats_;
const std::unique_ptr<HeapAlgorithm> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
@ -216,6 +220,21 @@ class HeapAlgorithm {
// Alloc allocates a buffer of 'size' bytes.
virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
// NoFragmentationStatsHeap overrides this method.
virtual void Alloc(const BufferValue* buffer, int64 size,
const HloInstruction* instruction) {
Alloc(buffer, size);
}
// Takes memory usage of subcomputations into account when calculating the
// memory usage of a computation. Currently, we don't handle buffer aliasing
// between computations entirely correctly. We are careful to not double count
// for the output buffers of whiles/conds/calls. But we don't take into
// account other aliases, such as for the while init. A more thorough solution
// would require something like BufferAssignment::BuildColocatedBufferSets.
// TODO(b/65835246):
// Since TuplePointsToAnalysis is being replaced with a module-aware alias
// analysis, it's not worth making major changes to HeapSimulator now.
virtual void AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
@ -240,6 +259,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
void Alloc(const BufferValue* buffer, int64 size) override;
void Alloc(const BufferValue* buffer, int64 size,
const HloInstruction* instruction) override;
void AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&

View File

@ -244,9 +244,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
*entry_computation, sequence.at(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
// HeapSimulator accounts for subcomputations. The max mem doesn't change
// because the while body isn't live during the peak.
EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
// HeapSimulator accounts for subcomputations. The output buffer is aliased,
// so we don't double count.
EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
*entry_computation, sequence.at(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
@ -350,7 +350,6 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
auto module = CreateNewModule();
const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
// param != 0
// Needs 17 bytes
@ -408,8 +407,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
*entry_computation, sequence.at(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
// HeapSimulator accounts for subcomputations
EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation(
// HeapSimulator accounts for subcomputations. Cond is the largest one.
// The output buffer of the while is aliased.
EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
*entry_computation, sequence.at(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());