[XLA] Refactor memory_space_assignment.cc.
Refactors IsIntervalAllowedInAlternateMemory() to a separate utils file so that it can be reused.

PiperOrigin-RevId: 316554798
Change-Id: Ibc6a4cffde6a1df233d375358164b373ea4ee7a6
parent 1b412edc89
commit 9136f5775e
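The reuse mentioned in the commit message means that any pass which depends on the new :memory_space_assignment_utils target can now call the extracted helper directly, instead of it being private to AlternateMemoryBestFitHeap. A minimal, hypothetical sketch of such a caller is shown below; the function name and filtering logic are illustrative only, and the only API taken from this change is MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory.

#include <vector>

#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"

namespace xla {

// Hypothetical helper: keep only the buffer intervals that the shared utility
// allows to be placed in the alternate memory space.
std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval>
FilterAlternateMemoryCandidates(
    const std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval>&
        intervals) {
  std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval> candidates;
  for (const auto& interval : intervals) {
    if (MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
            interval)) {
      candidates.push_back(interval);
    }
  }
  return candidates;
}

}  // namespace xla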
tensorflow/compiler/xla/service/BUILD

@@ -3304,6 +3304,15 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "memory_space_assignment_utils",
+    srcs = ["memory_space_assignment_utils.cc"],
+    hdrs = ["memory_space_assignment_utils.h"],
+    deps = [
+        ":heap_simulator",
+    ],
+)
+
 cc_library(
     name = "memory_space_assignment",
     srcs = ["memory_space_assignment.cc"],
@@ -3311,6 +3320,7 @@ cc_library(
     deps = [
         ":heap_simulator",
         ":hlo_cost_analysis",
+        ":memory_space_assignment_utils",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/core/lib/math:math_util",
     ],
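At the BUILD level, reuse works the same way the memory_space_assignment target does it above: a downstream library adds the new target to its deps. A hypothetical example (target and file names are illustrative, not part of this change):

cc_library(
    name = "my_buffer_placement_pass",
    srcs = ["my_buffer_placement_pass.cc"],
    deps = [
        ":memory_space_assignment_utils",
    ],
)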
tensorflow/compiler/xla/service/memory_space_assignment.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
 
 #include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
 #include "tensorflow/core/lib/math/math_util.h"
 
 namespace xla {
@@ -597,81 +598,6 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
   return colocated_intervals;
 }
 
-bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
-    const BufferInterval& interval) const {
-  // If the buffer is a tuple, don't use this algorithm for now. The buffers
-  // that are pointed to by the tuple will still use this algorithm. Because
-  // tuples are cheap to place in the alternate memory (they are just pointers)
-  // we don't need to use prefetch/evict logic.
-  if (interval.buffer->shape().IsTuple()) {
-    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-            << " in default mem because it is a tuple.";
-    return false;
-  }
-
-  // Don't place scalars in the alternate memory.
-  if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
-    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-            << " in default mem because it is a scalar.";
-    return false;
-  }
-
-  // The semantics of TupleSelect are weird: TupleSelect doesn't define a
-  // buffer, but just forwards the buffers in the either left or right side.
-  // This means the two different inputs to TupleSelect must not alias, yet they
-  // should be allocated in the same memory space, and both buffers must be kept
-  // alive for the entire live range of TupleSelect. Instead, just don't
-  // allocate TupleSelect in the alternate memory space.
-  // TODO(berkin): Not allocating add-dependencies either since they need to be
-  // treated specially. We should revisit this later.
-  for (const HloPosition& position : interval.buffer->positions()) {
-    if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
-        position.instruction->opcode() == HloOpcode::kAddDependency) {
-      VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-              << " in default mem because it has a tuple-select or "
-              << "add-dependency position.";
-      return false;
-    }
-  }
-
-  // Send and Recv HLOs return a request identifier. These should not be
-  // allocated in the alternate memory.
-  for (const HloPosition& position : interval.buffer->positions()) {
-    if ((position.instruction->opcode() == HloOpcode::kSend ||
-         position.instruction->opcode() == HloOpcode::kRecv)) {
-      // TODO(berkin): Send/recv buffers need a stable buffer allocation
-      // throughout sending/receiving. Disable memory space allocation for these
-      // for now.
-      if (position.index == ShapeIndex({0})) {
-        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-                << " in default mem because it is a send/recv buffer.";
-        return false;
-      } else if (position.index == ShapeIndex({1})) {
-        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-                << " in default mem because it is a request identifier for "
-                   "send/recv.";
-        return false;
-      }
-    }
-
-    if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
-         position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
-      // Disable memory space allocation for these for now.
-      if (position.index == ShapeIndex({0})) {
-        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-                << " in default mem because it is a collective-permute buffer.";
-        return false;
-      } else if (position.index == ShapeIndex({1})) {
-        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-                << " in default mem because it is a collective-permute buffer.";
-        return false;
-      }
-    }
-  }
-
-  return true;
-}
-
 bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
     const AllocationValue& value, const HloUse& use) const {
   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
@@ -710,8 +636,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
     if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
             shape, parameter_time, min_use_time)) {
       VLOG(4) << "While allocation not allowed in alternate memory. "
-              << "use time = " << min_use_time
-              << ", root time = " << root_time;
+              << "use time = " << min_use_time << ", root time = " << root_time;
       return false;
     }
     // Check if there is a required assignment for the while loop output.
@@ -897,7 +822,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
       continue;
     }
 
-    if (!IsIntervalAllowedInAlternateMemory(interval)) {
+    if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
+            interval)) {
       continue;
     }
 
tensorflow/compiler/xla/service/memory_space_assignment.h

@@ -909,10 +909,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
       const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
 
-  // Returns true if this buffer is allowed to be placed in the alternate
-  // memory.
-  bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
-
   // Returns true if the use is allowed in the alternate memory.
   bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
                                      const HloUse& use) const;
tensorflow/compiler/xla/service/memory_space_assignment_utils.cc (new file)

@@ -0,0 +1,95 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
+
+namespace xla {
+
+bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
+    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
+  // If the buffer is a tuple, don't use this algorithm for now. The buffers
+  // that are pointed to by the tuple will still use this algorithm. Because
+  // tuples are cheap to place in the alternate memory (they are just pointers)
+  // we don't need to use prefetch/evict logic.
+  if (interval.buffer->shape().IsTuple()) {
+    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+            << " in default mem because it is a tuple.";
+    return false;
+  }
+
+  // Don't place scalars in the alternate memory.
+  if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
+    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+            << " in default mem because it is a scalar.";
+    return false;
+  }
+
+  // The semantics of TupleSelect are weird: TupleSelect doesn't define a
+  // buffer, but just forwards the buffers in the either left or right side.
+  // This means the two different inputs to TupleSelect must not alias, yet they
+  // should be allocated in the same memory space, and both buffers must be kept
+  // alive for the entire live range of TupleSelect. Instead, just don't
+  // allocate TupleSelect in the alternate memory space.
+  // TODO(berkin): Not allocating add-dependencies either since they need to be
+  // treated specially. We should revisit this later.
+  for (const HloPosition& position : interval.buffer->positions()) {
+    if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
+        position.instruction->opcode() == HloOpcode::kAddDependency) {
+      VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+              << " in default mem because it has a tuple-select or "
+              << "add-dependency position.";
+      return false;
+    }
+  }
+
+  // Send and Recv HLOs return a request identifier. These should not be
+  // allocated in the alternate memory.
+  for (const HloPosition& position : interval.buffer->positions()) {
+    if ((position.instruction->opcode() == HloOpcode::kSend ||
+         position.instruction->opcode() == HloOpcode::kRecv)) {
+      // TODO(berkin): Send/recv buffers need a stable buffer allocation
+      // throughout sending/receiving. Disable memory space allocation for these
+      // for now.
+      if (position.index == ShapeIndex({0})) {
+        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+                << " in default mem because it is a send/recv buffer.";
+        return false;
+      } else if (position.index == ShapeIndex({1})) {
+        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+                << " in default mem because it is a request identifier for "
+                   "send/recv.";
+        return false;
+      }
+    }
+
+    if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
+         position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
+      // Disable memory space allocation for these for now.
+      if (position.index == ShapeIndex({0})) {
+        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+                << " in default mem because it is a collective-permute buffer.";
+        return false;
+      } else if (position.index == ShapeIndex({1})) {
+        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+                << " in default mem because it is a collective-permute buffer.";
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+}  // namespace xla
tensorflow/compiler/xla/service/memory_space_assignment_utils.h (new file)

@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_
+
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
+
+namespace xla {
+
+// Encapsulates common utility methods for memory space assignment.
+class MemorySpaceAssignmentUtils {
+ public:
+  // Returns true if this buffer is allowed to be placed in the alternate
+  // memory.
+  static bool IsIntervalAllowedInAlternateMemory(
+      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval);
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_