diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index c86abfd1de6..4fe97a5c792 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -893,6 +893,9 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( const AllocationValue& value, const HloUse& use) const { const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); + if (!options_.is_use_allowed_in_alternate_mem_fn(use)) { + return false; + } if (use.instruction->opcode() == HloOpcode::kWhile) { HloComputation* while_body = use.instruction->while_body(); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 341bf7e9895..7bffcc25523 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -400,6 +400,8 @@ class MemorySpaceAssignment { GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; using IsAllowedInAlternateMemoryFunction = std::function; + using IsUseAllowedInAlternateMemoryFunction = + std::function; // MemorySpaceAssignment uses a notion of a slow and large default memory // space and a fast and small alternate memory space. @@ -434,6 +436,11 @@ class MemorySpaceAssignment { // the opcode) to be placed on the alternate memory. IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn; + // This function can be used to prevent certain HloUses (e.g., based on + // the opcode) from being placed on the alternate memory. + IsUseAllowedInAlternateMemoryFunction is_use_allowed_in_alternate_mem_fn = + [](const HloUse&) { return true; }; + // Specifies the upper bound for number of outstanding prefetches and // evictions, -1 for unlimited. int64 max_outstanding_prefetches = -1;