From abdee716ce56d3e15dc01a83f46eaa6722b6fe3e Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Sun, 19 May 2019 11:37:30 -0700
Subject: [PATCH] [XLA:CPU] Don't parallelize in-place dynamic-update-slice.

Suppose we have out = dynamic-update-slice(in, update, indices...).

If `in` and `out` are different memory locations, this is basically a memcpy,
with most of the data coming from `in` and part coming from `update`.
However, if `in` and `out` are the same memory location, there's a faster
implementation: simply write the values from `update` over `in`/`out`. We
call this an in-place dynamic-update-slice (DUS).

In-place DUS is also possible for loop fusions which have a
dynamic-update-slice as the root. The criterion is basically the same: the
`in` operand to the dynamic-update-slice must be a parameter to the fusion,
and it must share a buffer with the `out` of the DUS.

Given a DUS op, we don't know whether we can implement it using the in-place
algorithm until after buffer assignment. And buffer assignment necessarily
occurs after all HLO transformations; it's illegal to change the graph after
doing buffer assignment. So although HLO passes can sometimes look at the
graph and say "this HLO can't be an in-place DUS", HLO passes *can't* say
"this HLO will definitely be an in-place DUS".

The job of ParallelTaskAssignment is to shard HLOs across multiple CPU cores.
To do this, it needs to know how many elements a particular HLO writes. Note
that in-place and out-of-place DUS ops write different numbers of elements!
This means that if we have an HLO which might be implemented as an in-place
DUS, we can't shard it: sharding an in-place DUS yields incorrect results,
likely due to out-of-bounds reads and writes.

PiperOrigin-RevId: 248950558
---
 tensorflow/compiler/xla/literal.h             |   2 +-
 tensorflow/compiler/xla/service/BUILD         |  21 ++
 tensorflow/compiler/xla/service/cpu/BUILD     |   1 +
 .../service/cpu/parallel_task_assignment.cc   |   6 +
 .../cpu/parallel_task_assignment_test.cc      |  45 ++++
 .../xla/service/dynamic_update_slice_test.cc  | 197 ++++++++++++++++++
 .../llvm_ir/dynamic_update_slice_util.cc      |  54 +++++
 .../llvm_ir/dynamic_update_slice_util.h       |  40 ++--
 8 files changed, 343 insertions(+), 23 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/dynamic_update_slice_test.cc

diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index c810ae9cbae..3c53592d040 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
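To make the difference in work concrete, here is a minimal C++ sketch of the two
strategies described in the commit message above. It is an editor's illustration
only, not code from this patch, and the helper names are hypothetical: the
out-of-place version writes every element of the output, while the in-place
version writes only the update's elements -- exactly the element count that
ParallelTaskAssignment would need to know ahead of buffer assignment.

  #include <algorithm>
  #include <cstddef>
  #include <vector>

  // Out-of-place DUS: `out` is a separate buffer, so all in.size() elements
  // are written (a full copy, then the update on top).
  void DynamicUpdateSliceOutOfPlace(const std::vector<float>& in,
                                    const std::vector<float>& update,
                                    size_t start, std::vector<float>& out) {
    out = in;  // Writes every element of the output buffer.
    std::copy(update.begin(), update.end(), out.begin() + start);
  }

  // In-place DUS: `in_out` is both input and output, so only update.size()
  // elements are written. (Real XLA also clamps `start` so the update stays
  // in bounds; that detail is omitted here.)
  void DynamicUpdateSliceInPlace(std::vector<float>& in_out,
                                 const std::vector<float>& update,
                                 size_t start) {
    std::copy(update.begin(), update.end(), in_out.begin() + start);
  }

A cost model that assumes the out-of-place element count would shard the
in-place case incorrectly, which is why the pass below declines to parallelize
anything that might become an in-place DUS.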
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 060e456af8f..72ec2d7a88b 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -402,6 +402,27 @@ tf_cc_test(
     ],
 )
 
+xla_test(
+    name = "dynamic_update_slice_test",
+    srcs = ["dynamic_update_slice_test.cc"],
+    backends = [
+        "cpu",
+        "gpu",
+    ],
+    deps = [
+        ":hlo_parser",
+        "//tensorflow/compiler/xla:execution_options_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/service/cpu:cpu_executable",
+        "//tensorflow/compiler/xla/service/cpu:parallel_task_assignment",
+        "//tensorflow/compiler/xla/service/cpu:target_machine_features",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+    ],
+)
+
 tf_cc_test(
     name = "dfs_hlo_visitor_with_default_test",
     srcs = ["dfs_hlo_visitor_with_default_test.cc"],
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 09f5c859af4..088a9b29fed 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -905,6 +905,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_cost_analysis",
         "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
         "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index b894bf502ca..23312e40f7e 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
 
 namespace xla {
 namespace cpu {
@@ -135,6 +136,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
   // *) Emit custom loops (kSelectAndScatter).
   // *) Operations that are not thread safe (like infeed and rng).
   // *) Tuple-shaped.
+  // *) Operations that might be implemented as an in-place
+  //    dynamic-update-slice, because we can't know how many output elements
+  //    they will write (out-of-place will touch the whole output buffer, while
+  //    in-place will only touch the updated elements).
   // TODO(b/27458679) Parallelize instructions which are skipped here.
   auto opcode = instruction->opcode();
   if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
@@ -148,6 +153,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
        PotentiallyImplementedAsEigenConvolution(*instruction,
                                                 target_machine_features_)) ||
       (opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
+      llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
       instruction->shape().IsTuple()) {
     return 1;
   }
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index 35ae62b42df..e2c93568b74 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -125,5 +125,50 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
   EXPECT_FALSE(changed);
 }
 
+TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
+  // A dynamic-update-slice within a while loop. This construction is an easy
+  // way to make a DUS which can be run "in-place" (i.e. the input and output
+  // are the same buffer, and running the DUS only writes to the updated
+  // elements).
+  const string hlo_string = R"(
+    HloModule test
+
+    body {
+      zero = s32[] constant(0)
+      one = s32[] constant(1)
+      ten = s32[] constant(10)
+      loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
+      i = s32[] get-tuple-element(loop_carry), index=0
+      i_plus_ten = s32[] add(i, ten)
+      update = u32[1,100] get-tuple-element(loop_carry), index=1
+      data = u32[10000,100] get-tuple-element(loop_carry), index=2
+      new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero)
+      new_i = s32[] add(i, one)
+      ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data)
+    }
+
+    cond {
+      loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
+      two = s32[] constant(2)
+      i = s32[] get-tuple-element(loop_carry), index=0
+      ROOT less-than = pred[] compare(i, two), direction=LT
+    }
+
+    ENTRY test {
+      zero = s32[] constant(0)
+      initial_i = s32[] parameter(0)
+      update = u32[1,100] parameter(1)
+      data = u32[10000,100] parameter(2)
+      tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data)
+      ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get()));
+  EXPECT_FALSE(changed);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/dynamic_update_slice_test.cc b/tensorflow/compiler/xla/service/dynamic_update_slice_test.cc
new file mode 100644
index 00000000000..a7caab685bf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/dynamic_update_slice_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+namespace xla {
+namespace {
+
+class DynamicUpdateSliceTest : public HloTestBase {};
+
+XLA_TEST_F(DynamicUpdateSliceTest, ShardedInPlaceDUS) {
+  // A dynamic-update-slice within a while loop. This construction is an easy
+  // way to make a DUS which can be run "in-place" (i.e. the input and output
+  // are the same buffer, and running the DUS only writes to the updated
+  // elements).
+  const char kModuleStr[] = R"(
+    HloModule test
+
+    body {
+      zero = s32[] constant(0)
+      one = s32[] constant(1)
+      ten = s32[] constant(10)
+      loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
+      i = s32[] get-tuple-element(loop_carry), index=0
+      i_plus_ten = s32[] add(i, ten)
+      update = u32[1,100] get-tuple-element(loop_carry), index=1
+      data = u32[10000,100] get-tuple-element(loop_carry), index=2
+      new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero)
+      new_i = s32[] add(i, one)
+      ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data)
+    }
+
+    cond {
+      loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
+      two = s32[] constant(2)
+      i = s32[] get-tuple-element(loop_carry), index=0
+      ROOT less-than = pred[] compare(i, two), direction=LT
+    }
+
+    ENTRY test {
+      zero = s32[] constant(0)
+      initial_i = s32[] parameter(0)
+      update = u32[1,100] parameter(1)
+      data = u32[10000,100] parameter(2)
+      tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data)
+      ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+  TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, MakeFakeArguments(module.get()));
+  fake_arguments[0] = LiteralUtil::CreateR0<int32>(0);
+
+  std::vector<Literal*> fake_argument_ptrs;
+  absl::c_transform(
+      fake_arguments, std::back_inserter(fake_argument_ptrs),
+      [](const Literal& literal) { return &const_cast<Literal&>(literal); });
+
+  ErrorSpec no_error(0, 0);
+  EXPECT_TRUE(RunAndCompare(std::move(module), fake_argument_ptrs, no_error));
+}
+
+// Regression test for a dynamic-update-slice involved in the expansion of a
+// kScatter op. Apologies for the large testcase; this proved difficult to
+// reduce. The bug we're checking for occurs when the dynamic-update-slice is
+// run in place but is sharded across cores by ParallelTaskAssigner.
+XLA_TEST_F(DynamicUpdateSliceTest, ExpandedScatter) {
+  const char kModuleStr[] = R"(
+HloModule TensorFlowScatter
+
+and.reduce_sub_computation {
+  lhs = pred[] parameter(0)
+  rhs = pred[] parameter(1)
+  ROOT and = pred[] and(lhs, rhs)
+}
+
+while_body {
+  param.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0)
+  get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
+  constant.4 = s32[] constant(1)
+  add = s32[] add(get-tuple-element.1, constant.4)
+  get-tuple-element.2 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(param.1), index=1
+  constant.8 = s32[] constant(0)
+  broadcast.1 = s32[5]{0} broadcast(constant.8), dimensions={}
+  get-tuple-element.3 = s32[16,4]{1,0} get-tuple-element(param.1), index=2
+  constant.5 = s32[] constant(0)
+  dynamic-slice = s32[1,4]{1,0} dynamic-slice(get-tuple-element.3, get-tuple-element.1, constant.5), dynamic_slice_sizes={1,4}
+  slice.18 = s32[1,1]{1,0} slice(dynamic-slice), slice={[0:1], [0:1]}
+  reshape.23 = s32[1]{0} reshape(slice.18)
+  reshape.4 = s32[4]{0} reshape(dynamic-slice)
+  slice.19 = s32[3]{0} slice(reshape.4), slice={[1:4]}
+  constant.6 = s32[1]{0} constant({0})
+  concatenate.1 = s32[5]{0} concatenate(reshape.23, slice.19, constant.6), dimensions={0}
+  compare.1 = pred[5]{0} compare(broadcast.1, concatenate.1), direction=LE
+  constant.9 = s32[5]{0} constant({7, 2, 95, 0, 0})
+  compare.2 = pred[5]{0} compare(constant.9, concatenate.1), direction=GE
+  and.1 = pred[5]{0} and(compare.1, compare.2)
+  constant.10 = pred[] constant(true)
+  reduce = pred[] reduce(and.1, constant.10), dimensions={0}, to_apply=and.reduce_sub_computation
+  broadcast.2 = pred[1,1,1,1,64]{4,3,2,1,0} broadcast(reduce), dimensions={}
+  reshape.24 = s32[] reshape(slice.18)
+  slice.26 = s32[1]{0} slice(reshape.4), slice={[1:2]}
+  reshape.10 = s32[] reshape(slice.26)
+  slice.27 = s32[1]{0} slice(reshape.4), slice={[2:3]}
+  reshape.11 = s32[] reshape(slice.27)
+  slice.28 = s32[1]{0} slice(reshape.4), slice={[3:4]}
+  reshape.12 = s32[] reshape(slice.28)
+  reshape.13 = s32[] reshape(constant.6)
+  dynamic-slice.2 = f32[1,1,1,1,64]{4,3,2,1,0} dynamic-slice(get-tuple-element.2, reshape.24, reshape.10, reshape.11, reshape.12, reshape.13), dynamic_slice_sizes={1,1,1,1,64}
+  get-tuple-element.4 = f32[16,64]{1,0} get-tuple-element(param.1), index=3
+  constant.7 = s32[] constant(0)
+  dynamic-slice.1 = f32[1,64]{1,0} dynamic-slice(get-tuple-element.4, get-tuple-element.1, constant.7), dynamic_slice_sizes={1,64}
+  reshape.28 = f32[1,1,1,1,64]{4,3,2,1,0} reshape(dynamic-slice.1)
+  add.1 = f32[1,1,1,1,64]{4,3,2,1,0} add(dynamic-slice.2, reshape.28)
+  select = f32[1,1,1,1,64]{4,3,2,1,0} select(broadcast.2, add.1, dynamic-slice.2)
+  reshape.29 = s32[] reshape(slice.18)
+  slice.29 = s32[1]{0} slice(reshape.4), slice={[1:2]}
+  reshape.15 = s32[] reshape(slice.29)
+  slice.30 = s32[1]{0} slice(reshape.4), slice={[2:3]}
+  reshape.16 = s32[] reshape(slice.30)
+  slice.31 = s32[1]{0} slice(reshape.4), slice={[3:4]}
+  reshape.17 = s32[] reshape(slice.31)
+  reshape.18 = s32[] reshape(constant.6)
+  dynamic-update-slice = f32[8,3,96,1,64]{4,3,2,1,0} dynamic-update-slice(get-tuple-element.2, select, reshape.29, reshape.15, reshape.16, reshape.17, reshape.18)
+  ROOT tuple.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(add, dynamic-update-slice, get-tuple-element.3, get-tuple-element.4)
+}
+
+while_cond {
+  param.0 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0)
+  get-tuple-element = s32[] get-tuple-element(param.0), index=0
+  constant.2 = s32[] constant(16)
+  ROOT compare = pred[] compare(get-tuple-element, constant.2), direction=LT
+}
+
+ENTRY main {
+  constant = s32[] constant(0)
+  z = f32[] constant(0)
+  b = f32[8,3,96,1,64]{4,3,2,1,0} broadcast(z), dimensions={}
+  i = s32[8,2,4]{2,1,0} parameter(0)
+  reshape = s32[16,4]{1,0} reshape(i)
+  u = f32[8,2,64]{2,1,0} parameter(1)
+  reshape.1 = f32[16,64]{1,0} reshape(u)
+  tuple = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(constant, b, reshape, reshape.1)
+  while = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) while(tuple), condition=while_cond, body=while_body
+  ROOT get-tuple-element.5 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(while), index=1
+}
+)";
+
+  Literal updates =
+      Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {8, 2, 64}));
+  updates.PopulateWithValue(1.0f);
+
+  Literal indices =
+      Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {8, 2, 4}));
+  indices
+      .Populate<int>([&](absl::Span<const int64> indices) -> int {
+        auto i = indices[2] + indices[1] * 4 + indices[0] * 2 * 4;
+        switch (indices[2]) {
+          case 0:
+            return i % 8;
+          case 1:
+            return i % 3;
+          case 2:
+            return i % 96;
+          default:
+            return 0;
+        }
+      })
+      .IgnoreError();
+
+  ErrorSpec no_error(0, 0);
+  EXPECT_TRUE(
+      RunAndCompare(ParseAndReturnVerifiedModule(kModuleStr).ValueOrDie(),
+                    {&indices, &updates}, no_error));
+}
+
+}  // anonymous namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
index 4974cb57db3..ba199f35712 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
@@ -23,6 +23,37 @@ limitations under the License.
 namespace xla {
 namespace llvm_ir {
 
+bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr) {
+  // Today we can't emit a dynamic-update-slice if the DUS node is
+  // parallelized; the emitter will not emit correct code. It's possible to
+  // change this, but then ParallelTaskAssigner would have to somehow know
+  // whether a node *will* be emitted as an in-place DUS, and it can't,
+  // because it doesn't have a buffer assignment when it runs.
+  if (!instr->outer_dimension_partitions().empty()) {
+    return false;
+  }
+
+  // Until we know the final buffer assignment, any unfused
+  // dynamic-update-slice might be implementable as an in-place DUS.
+  if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) {
+    return true;
+  }
+
+  // A fusion may be implementable as an in-place dynamic-update-slice if
+  //  - it's a loop fusion,
+  //  - dynamic-update-slice is the root of the fusion, and
+  //  - operand 0 of the dynamic-update-slice is a parameter to the fusion
+  //    (ignoring any get-tuple-element operations in the way).
+  if (instr->IsLoopFusion()) {
+    const HloInstruction* fused_root = instr->fused_expression_root();
+    return fused_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
+           fused_root->operand(0)->LatestNonGteAncestor()->opcode() ==
+               HloOpcode::kParameter;
+  }
+
+  return false;
+}
+
 bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
                                   const BufferAssignment& assignment) {
   CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
@@ -32,6 +63,29 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
          assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
 }
 
+bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
+                                           const BufferAssignment& assignment) {
+  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
+  if (!MayBeImplementedAsInPlaceDynamicUpdateSlice(fusion)) {
+    return false;
+  }
+
+  // Walk DynamicUpdateSlice operand(0) to fused parameter and get its
+  // associated operand. See if it shares an allocation with this operand.
+  HloInstruction* fused_root = fusion->fused_expression_root();
+  HloInstruction* fusion_operand;
+  ShapeIndex index;
+  std::tie(fusion_operand, index) =
+      fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
+  // MayBeImplementedAsInPlaceDynamicUpdateSlice should have ensured that
+  // fusion_operand is a parameter.
+  CHECK_EQ(fusion_operand->opcode(), HloOpcode::kParameter);
+  auto* operand = fusion->operand(fusion_operand->parameter_number());
+  return assignment.HasAllocationAt(operand, index) &&
+         assignment.HasAllocationAt(fusion, {}) &&
+         assignment.SharesSliceAtIndex(fusion, {}, operand, index);
+}
+
 // Shared implementation of EmitDynamicUpdateSliceInPlace and
 // EmitFusedDynamicUpdateSliceInPlace.
 //
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
index c4da28229d0..70dc368d5d7 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
@@ -30,6 +30,22 @@ namespace llvm_ir {
 using GeneratorForOperandIrArrays =
     std::function<std::vector<llvm_ir::IrArray>()>;
 
+// Determines whether the given instruction might be implemented as an
+// in-place dynamic-update-slice after we have a buffer assignment.
+//
+// If this returns false, then CanUpdateDynamicSliceInPlace and
+// CanEmitFusedDynamicUpdateSliceInPlace will also return false.
+//
+// This is useful if you want to check whether an instruction might be an
+// in-place DUS during an HLO pass, at which point you don't have a buffer
+// assignment.
+//
+// Note that simplifications to the HLO graph might change this function from
+// returning false to returning true. Specifically, simplifying the contents
+// of fusion nodes might cause a false->true transition. In general this isn't
+// a problem by the time you're calling this function, but beware.
+bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr);
+
 // Checks if we can emit code for the given DynamicUpdateSlice node that
 // updates its input in place. Returns true if the dynamic-update-slice's
 // array-to-be-updated and output share the same BufferAllocation::Slice.
@@ -40,28 +56,8 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
 
 // Checks if the given fusion node is amenable to being implemented by
 // EmitFusedDynamicUpdateSliceInPlace.
-inline bool CanEmitFusedDynamicUpdateSliceInPlace(
-    HloInstruction* fusion, const BufferAssignment& assignment) {
-  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
-  HloInstruction* fused_root = fusion->fused_expression_root();
-  if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice ||
-      !fusion->IsLoopFusion()) {
-    return false;
-  }
-  // Walk DynamicUpdateSlice operand(0) to fused parameter and get its
-  // associated operand. See if it shares an allocation with this operand.
-  HloInstruction* fusion_operand;
-  ShapeIndex index;
-  std::tie(fusion_operand, index) =
-      fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
-  if (fusion_operand->opcode() != HloOpcode::kParameter) {
-    return false;
-  }
-  auto* operand = fusion->operand(fusion_operand->parameter_number());
-  return assignment.HasAllocationAt(operand, index) &&
-         assignment.HasAllocationAt(fusion, {}) &&
-         assignment.SharesSliceAtIndex(fusion, {}, operand, index);
-}
+bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
+                                           const BufferAssignment& assignment);
 
 // Emits IR for running the given dynamic-update-slice op in-place -- that is,
 // where the input and output buffers share the same slice, so we can simply