[XLA] Refactor memory_space_assignment.cc.

Refactors IsIntervalAllowedInAlternateMemory() into a separate utils file (memory_space_assignment_utils) so that it can be reused.

PiperOrigin-RevId: 316554798
Change-Id: Ibc6a4cffde6a1df233d375358164b373ea4ee7a6
Author: A. Unique TensorFlower (2020-06-15 15:13:45 -07:00), committed by TensorFlower Gardener
Parent: 1b412edc89
Commit: 9136f5775e
5 changed files with 143 additions and 82 deletions
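
To illustrate the reuse the commit message mentions, here is a minimal sketch (not part of this commit) of how another heap-simulator-based pass could call the extracted helper. The FilterAlternateMemoryCandidates() function and its surrounding code are hypothetical; only the MemorySpaceAssignmentUtils API and the include path come from the diff below.

#include <vector>

#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"

namespace xla {

// Hypothetical helper: keeps only the buffer intervals that memory space
// assignment would consider placing in the alternate memory, reusing the
// check that this commit moves into MemorySpaceAssignmentUtils.
std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval>
FilterAlternateMemoryCandidates(
    const std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval>&
        intervals) {
  std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval> allowed;
  for (const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval :
       intervals) {
    if (MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
            interval)) {
      allowed.push_back(interval);
    }
  }
  return allowed;
}

}  // namespace xla

A target holding such a helper would also need ":memory_space_assignment_utils" in its BUILD deps, mirroring the dependency this diff adds to the memory_space_assignment library.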

tensorflow/compiler/xla/service/BUILD

@@ -3304,6 +3304,15 @@ tf_cc_test(
],
)
cc_library(
name = "memory_space_assignment_utils",
srcs = ["memory_space_assignment_utils.cc"],
hdrs = ["memory_space_assignment_utils.h"],
deps = [
":heap_simulator",
],
)
cc_library(
name = "memory_space_assignment",
srcs = ["memory_space_assignment.cc"],
@@ -3311,6 +3320,7 @@ cc_library(
deps = [
":heap_simulator",
":hlo_cost_analysis",
":memory_space_assignment_utils",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/core/lib/math:math_util",
],

tensorflow/compiler/xla/service/memory_space_assignment.cc

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace xla {
@@ -597,81 +598,6 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
return colocated_intervals;
}
bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
const BufferInterval& interval) const {
// If the buffer is a tuple, don't use this algorithm for now. The buffers
// that are pointed to by the tuple will still use this algorithm. Because
// tuples are cheap to place in the alternate memory (they are just pointers)
// we don't need to use prefetch/evict logic.
if (interval.buffer->shape().IsTuple()) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a tuple.";
return false;
}
// Don't place scalars in the alternate memory.
if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a scalar.";
return false;
}
// The semantics of TupleSelect are weird: TupleSelect doesn't define a
// buffer, but just forwards the buffers in the either left or right side.
// This means the two different inputs to TupleSelect must not alias, yet they
// should be allocated in the same memory space, and both buffers must be kept
// alive for the entire live range of TupleSelect. Instead, just don't
// allocate TupleSelect in the alternate memory space.
// TODO(berkin): Not allocating add-dependencies either since they need to be
// treated specially. We should revisit this later.
for (const HloPosition& position : interval.buffer->positions()) {
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
position.instruction->opcode() == HloOpcode::kAddDependency) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it has a tuple-select or "
<< "add-dependency position.";
return false;
}
}
// Send and Recv HLOs return a request identifier. These should not be
// allocated in the alternate memory.
for (const HloPosition& position : interval.buffer->positions()) {
if ((position.instruction->opcode() == HloOpcode::kSend ||
position.instruction->opcode() == HloOpcode::kRecv)) {
// TODO(berkin): Send/recv buffers need a stable buffer allocation
// throughout sending/receiving. Disable memory space allocation for these
// for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a send/recv buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a request identifier for "
"send/recv.";
return false;
}
}
if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
// Disable memory space allocation for these for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
}
}
}
return true;
}
bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
const AllocationValue& value, const HloUse& use) const {
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
@@ -710,8 +636,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
shape, parameter_time, min_use_time)) {
VLOG(4) << "While allocation not allowed in alternate memory. "
<< "use time = " << min_use_time
<< ", root time = " << root_time;
<< "use time = " << min_use_time << ", root time = " << root_time;
return false;
}
// Check if there is a required assignment for the while loop output.
@@ -897,7 +822,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
continue;
}
if (!IsIntervalAllowedInAlternateMemory(interval)) {
if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
interval)) {
continue;
}

tensorflow/compiler/xla/service/memory_space_assignment.h

@@ -909,10 +909,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
// Returns true if this buffer is allowed to be placed in the alternate
// memory.
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
// Returns true if the use is allowed in the alternate memory.
bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
const HloUse& use) const;

tensorflow/compiler/xla/service/memory_space_assignment_utils.cc

@@ -0,0 +1,95 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
namespace xla {
bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
// If the buffer is a tuple, don't use this algorithm for now. The buffers
// that are pointed to by the tuple will still use this algorithm. Because
// tuples are cheap to place in the alternate memory (they are just pointers)
// we don't need to use prefetch/evict logic.
if (interval.buffer->shape().IsTuple()) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a tuple.";
return false;
}
// Don't place scalars in the alternate memory.
if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a scalar.";
return false;
}
// The semantics of TupleSelect are weird: TupleSelect doesn't define a
// buffer, but just forwards the buffers from either its left or right side.
// This means the two different inputs to TupleSelect must not alias, yet they
// should be allocated in the same memory space, and both buffers must be kept
// alive for the entire live range of TupleSelect. Instead, just don't
// allocate TupleSelect in the alternate memory space.
// TODO(berkin): Not allocating add-dependencies either since they need to be
// treated specially. We should revisit this later.
for (const HloPosition& position : interval.buffer->positions()) {
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
position.instruction->opcode() == HloOpcode::kAddDependency) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it has a tuple-select or "
<< "add-dependency position.";
return false;
}
}
// Send and Recv HLOs return a request identifier. These should not be
// allocated in the alternate memory.
for (const HloPosition& position : interval.buffer->positions()) {
if ((position.instruction->opcode() == HloOpcode::kSend ||
position.instruction->opcode() == HloOpcode::kRecv)) {
// TODO(berkin): Send/recv buffers need a stable buffer allocation
// throughout sending/receiving. Disable memory space allocation for these
// for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a send/recv buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a request identifier for "
"send/recv.";
return false;
}
}
if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
// Disable memory space allocation for these for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
}
}
}
return true;
}
} // namespace xla

tensorflow/compiler/xla/service/memory_space_assignment_utils.h

@@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_
#include "tensorflow/compiler/xla/service/heap_simulator.h"
namespace xla {
// Encapsulates common utility methods for memory space assignment.
class MemorySpaceAssignmentUtils {
public:
// Returns true if this buffer is allowed to be placed in the alternate
// memory.
static bool IsIntervalAllowedInAlternateMemory(
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_