From df3f233362a97ef1b0c3d551edb320f2aed61edf Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Fri, 11 Dec 2020 13:07:59 -0800 Subject: [PATCH] [XLA] Add mechanism to disable alternative mem allocation at uses PiperOrigin-RevId: 347059810 Change-Id: Ie625eecb08ca195163a11e5b65ec776f9e8e2036 --- tensorflow/compiler/xla/service/memory_space_assignment.cc | 3 +++ tensorflow/compiler/xla/service/memory_space_assignment.h | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index c86abfd1de6..4fe97a5c792 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -893,6 +893,9 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( const AllocationValue& value, const HloUse& use) const { const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); + if (!options_.is_use_allowed_in_alternate_mem_fn(use)) { + return false; + } if (use.instruction->opcode() == HloOpcode::kWhile) { HloComputation* while_body = use.instruction->while_body(); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 341bf7e9895..7bffcc25523 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -400,6 +400,8 @@ class MemorySpaceAssignment { GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; using IsAllowedInAlternateMemoryFunction = std::function; + using IsUseAllowedInAlternateMemoryFunction = + std::function; // MemorySpaceAssignment uses a notion of a slow and large default memory // space and a fast and small alternate memory space. @@ -434,6 +436,11 @@ class MemorySpaceAssignment { // the opcode) to be placed on the alternate memory. IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn; + // This function can be used to prevent certain HloUses (e.g., based on + // the opcode) to be placed on the alternate memory. + IsUseAllowedInAlternateMemoryFunction is_use_allowed_in_alternate_mem_fn = + [](const HloUse&) { return true; }; + // Specifies the upper bound for number of outstanding prefetches and // evictions, -1 for unlimited. int64 max_outstanding_prefetches = -1;