[TF:XLA] Handle some of the buffer aliasing across computations in HeapSimulator.
The new modeling of subcomputations is still not entirely accurate, but it is probably not worth putting more work into, since TuplePointsToAnalysis will be removed from HeapSimulator soon.

PiperOrigin-RevId: 209646234
commit a7e961ac88 (parent dc62ab7a7c)
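In brief, the change below threads the allocating HloInstruction from HeapSimulator::Alloc into the heap algorithm, and NoFragmentationStatsHeap then skips counting the output buffer of kWhile/kCall/kConditional, since that buffer aliases the root buffer of the called subcomputation. The standalone C++ sketch below illustrates only that counting idea; the class, enum, and method names (PeakMemoryTracker, Opcode, AliasesSubcomputationRoot) are invented for the example and are not part of the XLA API.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the opcode of the instruction that defines a buffer.
enum class Opcode { kOther, kWhile, kCall, kConditional };

// Tracks a fragmentation-free upper bound on live memory, mirroring the idea
// in the diff: the output buffer of a while/call/conditional aliases the root
// buffer of its subcomputation, so it must not be counted a second time.
class PeakMemoryTracker {
 public:
  void Alloc(int64_t size, Opcode defining_opcode) {
    if (AliasesSubcomputationRoot(defining_opcode)) return;  // skip double count
    current_ += size;
    peak_ = std::max(peak_, current_);
  }
  void Free(int64_t size, Opcode defining_opcode) {
    if (AliasesSubcomputationRoot(defining_opcode)) return;  // never counted above
    current_ -= size;
  }
  int64_t peak() const { return peak_; }

 private:
  static bool AliasesSubcomputationRoot(Opcode opcode) {
    return opcode == Opcode::kWhile || opcode == Opcode::kCall ||
           opcode == Opcode::kConditional;
  }
  int64_t current_ = 0;
  int64_t peak_ = 0;
};

int main() {
  PeakMemoryTracker tracker;
  tracker.Alloc(16, Opcode::kOther);  // regular buffer: counted
  tracker.Alloc(16, Opcode::kWhile);  // while output: aliased with body root, skipped
  std::cout << tracker.peak() << "\n";  // prints 16, not 32
  return 0;
}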
@@ -627,7 +627,7 @@ Status BufferAssignment::ComputeSummaryStats() {
     stats_.total_allocation_bytes += allocation.size();
   }
 
-  // Only compute total fragmentation if all computations are sequential.
+  // Only compute total fragmentation if all computations have schedules.
   SequentialHloOrdering::HloModuleSequence module_sequence;
   for (const auto& computation : module_->computations()) {
     const std::vector<const HloInstruction*>* sequence =
@@ -144,7 +144,7 @@ Status HeapSimulator::RunComputation(
         }
       } else {
         // A GetTupleElement doesn't need to keep all of its operand's buffers
-        // alive. It only needs the buffers that relate to the element its
+        // alive. It only needs the buffers that relate to the element it's
         // extracting, and the tuple it's extracting from, but not the buffers
         // for the other elements.
         for (const BufferValue* buffer : points_to.element({})) {
@@ -277,13 +277,13 @@ Status HeapSimulator::RunComputation(
                                               *memory_by_computation_);
      }
 
-    // If the whole module is sequential, we can save memory by running the
-    // heap-simulation for sub-computations inline. E.g. the buffers for the
-    // condition and body of a kWhile instruction are only live for the duration
-    // of the instruction itself.
+    // If all computations in the module have been scheduled, we can save memory
+    // by running the heap-simulation for sub-computations inline. E.g. the
+    // buffers for the condition and body of a kWhile instruction are only live
+    // for the duration of the instruction itself.
     //
     // The order that the sub-computations are simulated does not affect
-    // correctness; since the whole module is sequential, we know that the
+    // correctness; since the whole module has been scheduled, we know that the
     // sub-computations will never be run concurrently.
     if (module_sequence_ != nullptr) {
       if (instruction->opcode() == HloOpcode::kCall ||
@@ -380,9 +380,10 @@ void HeapSimulator::Alloc(const BufferValue* buffer,
 
   allocated_buffers_.insert(buffer);
   const int64 size = size_fn_(*buffer);
-  algorithm_->Alloc(buffer, size);
-  no_fragmentation_stats_->Alloc(buffer, size);
+  const HloInstruction* instruction_to_calc_aliasing =
+      memory_by_computation_ == nullptr ? nullptr : instruction;
+  algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing);
+  no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing);
 
   FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                  nullptr);
 }
@@ -520,6 +521,18 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
   }
 }
 
+void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size,
+                                     const HloInstruction* instruction) {
+  // The output buffer of while/call/conditional is always aliased with the
+  // output buffer of the root instruction in the body. Don't double count.
+  if (instruction == nullptr ||
+      (instruction->opcode() != HloOpcode::kWhile &&
+       instruction->opcode() != HloOpcode::kCall &&
+       instruction->opcode() != HloOpcode::kConditional)) {
+    Alloc(buffer, size);
+  }
+}
+
 void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
     const HloInstruction* instruction,
     const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
@@ -36,6 +36,7 @@ namespace xla {
 
 // Forward declare classes defined below.
 class HeapAlgorithm;
+class NoFragmentationStatsHeap;
 
 // HeapSimulator assigns buffer offsets by running a simulation of a regular
 // memory heap with Alloc and Free calls. It only works for completely
@@ -161,7 +162,10 @@ class HeapSimulator {
                  const HloInstruction* instruction,
                  const BufferValue* shared_with_canonical);
 
-  const std::unique_ptr<HeapAlgorithm> no_fragmentation_stats_;
+  // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap,
+  // in which case we are calculating the same allocs/frees twice in the
+  // simulation.
+  const std::unique_ptr<NoFragmentationStatsHeap> no_fragmentation_stats_;
   const std::unique_ptr<HeapAlgorithm> algorithm_;
   const BufferValue::SizeFunction size_fn_;
   const Options options_;
@@ -216,6 +220,21 @@ class HeapAlgorithm {
   // Alloc allocates a buffer of 'size' bytes.
   virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
 
+  // NoFragmentationStatsHeap overrides this method.
+  virtual void Alloc(const BufferValue* buffer, int64 size,
+                     const HloInstruction* instruction) {
+    Alloc(buffer, size);
+  }
+
+  // Takes memory usage of subcomputations into account when calculating the
+  // memory usage of a computation. Currently, we don't handle buffer aliasing
+  // between computations entirely correctly. We are careful to not double count
+  // for the output buffers of whiles/conds/calls. But we don't take into
+  // account other aliases, such as for the while init. A more thorough solution
+  // would require something like BufferAssignment::BuildColocatedBufferSets.
+  // TODO(b/65835246):
+  // Since TuplePointsToAnalysis is being replaced with a module-aware alias
+  // analysis, it's not worth making major changes to HeapSimulator now.
   virtual void AccountForSubcomputationMemory(
       const HloInstruction* instruction,
       const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
@@ -240,6 +259,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
 
   void Alloc(const BufferValue* buffer, int64 size) override;
 
+  void Alloc(const BufferValue* buffer, int64 size,
+             const HloInstruction* instruction) override;
+
   void AccountForSubcomputationMemory(
       const HloInstruction* instruction,
       const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
@@ -244,9 +244,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
                     *entry_computation, sequence.at(entry_computation),
                     *points_to_analysis, size_fn)
                     .ValueOrDie());
-  // HeapSimulator accounts for subcomputations. The max mem doesn't change
-  // because the while body isn't live during the peak.
-  EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
+  // so we don't double count.
+  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                     *entry_computation, sequence.at(entry_computation),
                     *points_to_analysis, size_fn, &memory_by_computation)
                     .ValueOrDie());
@@ -350,7 +350,6 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
 TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
   auto module = CreateNewModule();
   const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
-  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
 
   // param != 0
   // Needs 17 bytes
@@ -408,8 +407,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
                     *entry_computation, sequence.at(entry_computation),
                     *points_to_analysis, size_fn)
                     .ValueOrDie());
-  // HeapSimulator accounts for subcomputations
-  EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation(
+  // HeapSimulator accounts for subcomputations. Cond is the largest one.
+  // The output buffer of the while is aliased.
+  EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
                     *entry_computation, sequence.at(entry_computation),
                     *points_to_analysis, size_fn, &memory_by_computation)
                     .ValueOrDie());
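(A back-of-the-envelope check, inferred from the test shapes rather than stated in the commit: the while's F32[4] output buffer occupies 16 bytes under the usual byte-size function, and dropping its double count lowers both expectations by exactly that amount, 80 - 16 = 64 and 33 - 16 = 17.)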