diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 40c237c5e6d..acd35cbc153 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3304,6 +3304,15 @@ tf_cc_test( ], ) +cc_library( + name = "memory_space_assignment_utils", + srcs = ["memory_space_assignment_utils.cc"], + hdrs = ["memory_space_assignment_utils.h"], + deps = [ + ":heap_simulator", + ], +) + cc_library( name = "memory_space_assignment", srcs = ["memory_space_assignment.cc"], @@ -3311,6 +3320,7 @@ cc_library( deps = [ ":heap_simulator", ":hlo_cost_analysis", + ":memory_space_assignment_utils", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/core/lib/math:math_util", ], diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 21baaf1c7d5..388a2e18f38 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/memory_space_assignment.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h" #include "tensorflow/core/lib/math/math_util.h" namespace xla { @@ -597,81 +598,6 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( return colocated_intervals; } -bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( - const BufferInterval& interval) const { - // If the buffer is a tuple, don't use this algorithm for now. The buffers - // that are pointed to by the tuple will still use this algorithm. Because - // tuples are cheap to place in the alternate memory (they are just pointers) - // we don't need to use prefetch/evict logic. 
- if (interval.buffer->shape().IsTuple()) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() - << " in default mem because it is a tuple."; - return false; - } - - // Don't place scalars in the alternate memory. - if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() - << " in default mem because it is a scalar."; - return false; - } - - // The semantics of TupleSelect are weird: TupleSelect doesn't define a - // buffer, but just forwards the buffers in the either left or right side. - // This means the two different inputs to TupleSelect must not alias, yet they - // should be allocated in the same memory space, and both buffers must be kept - // alive for the entire live range of TupleSelect. Instead, just don't - // allocate TupleSelect in the alternate memory space. - // TODO(berkin): Not allocating add-dependencies either since they need to be - // treated specially. We should revisit this later. - for (const HloPosition& position : interval.buffer->positions()) { - if (position.instruction->opcode() == HloOpcode::kTupleSelect || - position.instruction->opcode() == HloOpcode::kAddDependency) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() - << " in default mem because it has a tuple-select or " - << "add-dependency position."; - return false; - } - } - - // Send and Recv HLOs return a request identifier. These should not be - // allocated in the alternate memory. - for (const HloPosition& position : interval.buffer->positions()) { - if ((position.instruction->opcode() == HloOpcode::kSend || - position.instruction->opcode() == HloOpcode::kRecv)) { - // TODO(berkin): Send/recv buffers need a stable buffer allocation - // throughout sending/receiving. Disable memory space allocation for these - // for now. 
- if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() - << " in default mem because it is a send/recv buffer."; - return false; - } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() - << " in default mem because it is a request identifier for " - "send/recv."; - return false; - } - } - - if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart || - position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) { - // Disable memory space allocation for these for now. - if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() - << " in default mem because it is a collective-permute buffer."; - return false; - } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() - << " in default mem because it is a collective-permute buffer."; - return false; - } - } - } - - return true; -} - bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( const AllocationValue& value, const HloUse& use) const { const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); @@ -710,8 +636,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( shape, parameter_time, min_use_time)) { VLOG(4) << "While allocation not allowed in alternate memory. " - << "use time = " << min_use_time - << ", root time = " << root_time; + << "use time = " << min_use_time << ", root time = " << root_time; return false; } // Check if there is a required assignment for the while loop output. 
@@ -897,7 +822,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { continue; } - if (!IsIntervalAllowedInAlternateMemory(interval)) { + if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( + interval)) { continue; } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index b8f47e73b8c..f9e5738d17e 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -909,10 +909,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { static MemorySpaceAssignment::Allocation* GetLiveAllocationAt( const MemorySpaceAssignment::AllocationSequence& allocations, int64 time); - // Returns true if this buffer is allowed to be placed in the alternate - // memory. - bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const; - // Returns true if the use is allowed in the alternate memory. bool IsUseAllowedInAlternateMemory(const AllocationValue& value, const HloUse& use) const; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc new file mode 100644 index 00000000000..0215f007c9c --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
+
+namespace xla {
+
+bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
+    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
+  // If the buffer is a tuple, don't use this algorithm for now. The buffers
+  // that are pointed to by the tuple will still use this algorithm. Because
+  // tuples are cheap to place in the alternate memory (they are just pointers)
+  // we don't need to use prefetch/evict logic.
+  if (interval.buffer->shape().IsTuple()) {
+    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+            << " in default mem because it is a tuple.";
+    return false;
+  }
+
+  // Don't place scalars in the alternate memory.
+  if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
+    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+            << " in default mem because it is a scalar.";
+    return false;
+  }
+
+  // The semantics of TupleSelect are weird: TupleSelect doesn't define a
+  // buffer, but just forwards the buffers in either the left or right side.
+  // This means the two different inputs to TupleSelect must not alias, yet they
+  // should be allocated in the same memory space, and both buffers must be kept
+  // alive for the entire live range of TupleSelect. Instead, just don't
+  // allocate TupleSelect in the alternate memory space.
+  // TODO(berkin): Not allocating add-dependencies either since they need to be
+  // treated specially. We should revisit this later.
+ for (const HloPosition& position : interval.buffer->positions()) { + if (position.instruction->opcode() == HloOpcode::kTupleSelect || + position.instruction->opcode() == HloOpcode::kAddDependency) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it has a tuple-select or " + << "add-dependency position."; + return false; + } + } + + // Send and Recv HLOs return a request identifier. These should not be + // allocated in the alternate memory. + for (const HloPosition& position : interval.buffer->positions()) { + if ((position.instruction->opcode() == HloOpcode::kSend || + position.instruction->opcode() == HloOpcode::kRecv)) { + // TODO(berkin): Send/recv buffers need a stable buffer allocation + // throughout sending/receiving. Disable memory space allocation for these + // for now. + if (position.index == ShapeIndex({0})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a send/recv buffer."; + return false; + } else if (position.index == ShapeIndex({1})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a request identifier for " + "send/recv."; + return false; + } + } + + if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart || + position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) { + // Disable memory space allocation for these for now. 
+ if (position.index == ShapeIndex({0})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a collective-permute buffer."; + return false; + } else if (position.index == ShapeIndex({1})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a collective-permute buffer."; + return false; + } + } + } + + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.h b/tensorflow/compiler/xla/service/memory_space_assignment_utils.h new file mode 100644 index 00000000000..651ac107c25 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_ + +#include "tensorflow/compiler/xla/service/heap_simulator.h" + +namespace xla { + +// Encapsulates common utility methods for memory space assignment. +class MemorySpaceAssignmentUtils { + public: + // Returns true if this buffer is allowed to be placed in the alternate + // memory. 
+ static bool IsIntervalAllowedInAlternateMemory( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_