[XLA:CPU] Don't parallelize in-place dynamic-update-slice.
Suppose we have out = dynamic-update-slice(in, update, indices...). If `in` and `out` are different memory locations, this is basically a memcpy, with most of the data coming from `in` and part coming from `update`. However if `in` and `out` are the same memory location, there's a faster implementation: Simply write the values from `update` over `in`/`out`. We call this an in-place dynamic-update-slice (DUS). In-place DUS is also possible for loop fusions which have a dynamic-update-slice as the root. The criterion is basically the same: The `in` operand to the dynamic-update-slice must be a parameter to the fusion, and it must share a buffer with the `out` of the DUS. Given a DUS op, we don't know whether we can implement it using the in-place algorithm until after buffer assignment. And buffer assignment necessarily occurs after all HLO transformations; it's illegal to change the graph after doing buffer assignment. So although HLO passes can sometimes look at the graph and say "this HLO can't be an in-place DUS", HLO passes *can't* say "this HLO will definitely be an in-place DUS". The job of ParallelTaskAssignment is to shard HLOs up across multiple CPU cores. To do this, it needs to know how many elements a particular HLO writes. Note that in-place and out-of-place DUS ops write different numbers of elements! This means that if we have an HLO which might be implemented as an in-place DUS, we can't shard it. Sharding an in-place DUS yields incorrect results, maybe due to out-of-bounds reads/writes. PiperOrigin-RevId: 248950558
This commit is contained in:
parent
810b454169
commit
abdee716ce
@ -1,4 +1,4 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
@ -402,6 +402,27 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "dynamic_update_slice_test",
|
||||
srcs = ["dynamic_update_slice_test.cc"],
|
||||
backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
deps = [
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_executable",
|
||||
"//tensorflow/compiler/xla/service/cpu:parallel_task_assignment",
|
||||
"//tensorflow/compiler/xla/service/cpu:target_machine_features",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "dfs_hlo_visitor_with_default_test",
|
||||
srcs = ["dfs_hlo_visitor_with_default_test.cc"],
|
||||
|
@ -905,6 +905,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
@ -135,6 +136,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
|
||||
// *) Emit custom loops (kSelectAndScatter).
|
||||
// *) Operations that are not thread safe (like infeed and rng).
|
||||
// *) Tuple-shaped.
|
||||
// *) Operations that might be implemented as an in-place
|
||||
// dynamic-update-slice, because we can't know how many output elements
|
||||
// they will write (out-of-place will touch the whole output buffer, while
|
||||
// in-place will only touch the updated elements).
|
||||
// TODO(b/27458679) Parallelize instructions which are skipped here.
|
||||
auto opcode = instruction->opcode();
|
||||
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
|
||||
@ -148,6 +153,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
|
||||
PotentiallyImplementedAsEigenConvolution(*instruction,
|
||||
target_machine_features_)) ||
|
||||
(opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
|
||||
llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
|
||||
instruction->shape().IsTuple()) {
|
||||
return 1;
|
||||
}
|
||||
|
@ -125,5 +125,50 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
|
||||
EXPECT_FALSE(changed);
|
||||
}
|
||||
|
||||
TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
  // Builds a dynamic-update-slice inside a while loop.  This is a simple way
  // to construct a DUS that may run "in-place" (input and output share one
  // buffer, so only the updated elements are written).  Such a DUS must not
  // be sharded by ParallelTaskAssigner, so we expect no change to the module.
  constexpr char kHloText[] = R"(
HloModule test

body {
  zero = s32[] constant(0)
  one = s32[] constant(1)
  ten = s32[] constant(10)
  loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
  i = s32[] get-tuple-element(loop_carry), index=0
  i_plus_ten = s32[] add(i, ten)
  update = u32[1,100] get-tuple-element(loop_carry), index=1
  data = u32[10000,100] get-tuple-element(loop_carry), index=2
  new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero)
  new_i = s32[] add(i, one)
  ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data)
}

cond {
  loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
  two = s32[] constant(2)
  i = s32[] get-tuple-element(loop_carry), index=0
  ROOT less-than = pred[] compare(i, two), direction=LT
}

ENTRY test {
  zero = s32[] constant(0)
  initial_i = s32[] parameter(0)
  update = u32[1,100] parameter(1)
  data = u32[10000,100] parameter(2)
  tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data)
  ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(module.get()));
  EXPECT_FALSE(changed);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
197
tensorflow/compiler/xla/service/dynamic_update_slice_test.cc
Normal file
197
tensorflow/compiler/xla/service/dynamic_update_slice_test.cc
Normal file
@ -0,0 +1,197 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Fixture for end-to-end dynamic-update-slice execution tests; inherits all
// behavior from HloTestBase (module parsing, RunAndCompare, etc.).
class DynamicUpdateSliceTest : public HloTestBase {};
|
||||
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, ShardedInPlaceDUS) {
  // Builds a dynamic-update-slice inside a while loop — an easy construction
  // of a DUS that can run "in-place" (input and output share a buffer, and
  // only the updated elements are written).  Runs it and compares against the
  // reference backend with zero tolerance.
  const char kModuleStr[] = R"(
HloModule test

body {
  zero = s32[] constant(0)
  one = s32[] constant(1)
  ten = s32[] constant(10)
  loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
  i = s32[] get-tuple-element(loop_carry), index=0
  i_plus_ten = s32[] add(i, ten)
  update = u32[1,100] get-tuple-element(loop_carry), index=1
  data = u32[10000,100] get-tuple-element(loop_carry), index=2
  new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero)
  new_i = s32[] add(i, one)
  ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data)
}

cond {
  loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
  two = s32[] constant(2)
  i = s32[] get-tuple-element(loop_carry), index=0
  ROOT less-than = pred[] compare(i, two), direction=LT
}

ENTRY test {
  zero = s32[] constant(0)
  initial_i = s32[] parameter(0)
  update = u32[1,100] parameter(1)
  data = u32[10000,100] parameter(2)
  tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data)
  ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, MakeFakeArguments(module.get()));
  // Pin the loop counter to 0 so the while loop runs a deterministic number
  // of iterations.
  fake_arguments[0] = LiteralUtil::CreateR0<int32>(0);

  // RunAndCompare wants raw Literal pointers; collect them from the owned
  // fake arguments.
  std::vector<Literal*> argument_ptrs;
  argument_ptrs.reserve(fake_arguments.size());
  for (Literal& argument : fake_arguments) {
    argument_ptrs.push_back(&argument);
  }

  ErrorSpec no_error(0, 0);
  EXPECT_TRUE(RunAndCompare(std::move(module), argument_ptrs, no_error));
}
|
||||
|
||||
// Regression test for a dynamic-update-slice involved in the expansion of a
|
||||
// kScatter op. Apologies for the large testcase, this proved difficult to
|
||||
// reduce. The bug we're checking for occurs when the dynamic-update-slice is
|
||||
// run in place but is sharded across cores by ParallelTaskAssigner.
|
||||
XLA_TEST_F(DynamicUpdateSliceTest, ExpandedScatter) {
|
||||
const char kModuleStr[] = R"(
|
||||
HloModule TensorFlowScatter
|
||||
|
||||
and.reduce_sub_computation {
|
||||
lhs = pred[] parameter(0)
|
||||
rhs = pred[] parameter(1)
|
||||
ROOT and = pred[] and(lhs, rhs)
|
||||
}
|
||||
|
||||
while_body {
|
||||
param.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0)
|
||||
get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
|
||||
constant.4 = s32[] constant(1)
|
||||
add = s32[] add(get-tuple-element.1, constant.4)
|
||||
get-tuple-element.2 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(param.1), index=1
|
||||
constant.8 = s32[] constant(0)
|
||||
broadcast.1 = s32[5]{0} broadcast(constant.8), dimensions={}
|
||||
get-tuple-element.3 = s32[16,4]{1,0} get-tuple-element(param.1), index=2
|
||||
constant.5 = s32[] constant(0)
|
||||
dynamic-slice = s32[1,4]{1,0} dynamic-slice(get-tuple-element.3, get-tuple-element.1, constant.5), dynamic_slice_sizes={1,4}
|
||||
slice.18 = s32[1,1]{1,0} slice(dynamic-slice), slice={[0:1], [0:1]}
|
||||
reshape.23 = s32[1]{0} reshape(slice.18)
|
||||
reshape.4 = s32[4]{0} reshape(dynamic-slice)
|
||||
slice.19 = s32[3]{0} slice(reshape.4), slice={[1:4]}
|
||||
constant.6 = s32[1]{0} constant({0})
|
||||
concatenate.1 = s32[5]{0} concatenate(reshape.23, slice.19, constant.6), dimensions={0}
|
||||
compare.1 = pred[5]{0} compare(broadcast.1, concatenate.1), direction=LE
|
||||
constant.9 = s32[5]{0} constant({7, 2, 95, 0, 0})
|
||||
compare.2 = pred[5]{0} compare(constant.9, concatenate.1), direction=GE
|
||||
and.1 = pred[5]{0} and(compare.1, compare.2)
|
||||
constant.10 = pred[] constant(true)
|
||||
reduce = pred[] reduce(and.1, constant.10), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||
broadcast.2 = pred[1,1,1,1,64]{4,3,2,1,0} broadcast(reduce), dimensions={}
|
||||
reshape.24 = s32[] reshape(slice.18)
|
||||
slice.26 = s32[1]{0} slice(reshape.4), slice={[1:2]}
|
||||
reshape.10 = s32[] reshape(slice.26)
|
||||
slice.27 = s32[1]{0} slice(reshape.4), slice={[2:3]}
|
||||
reshape.11 = s32[] reshape(slice.27)
|
||||
slice.28 = s32[1]{0} slice(reshape.4), slice={[3:4]}
|
||||
reshape.12 = s32[] reshape(slice.28)
|
||||
reshape.13 = s32[] reshape(constant.6)
|
||||
dynamic-slice.2 = f32[1,1,1,1,64]{4,3,2,1,0} dynamic-slice(get-tuple-element.2, reshape.24, reshape.10, reshape.11, reshape.12, reshape.13), dynamic_slice_sizes={1,1,1,1,64}
|
||||
get-tuple-element.4 = f32[16,64]{1,0} get-tuple-element(param.1), index=3
|
||||
constant.7 = s32[] constant(0)
|
||||
dynamic-slice.1 = f32[1,64]{1,0} dynamic-slice(get-tuple-element.4, get-tuple-element.1, constant.7), dynamic_slice_sizes={1,64}
|
||||
reshape.28 = f32[1,1,1,1,64]{4,3,2,1,0} reshape(dynamic-slice.1)
|
||||
add.1 = f32[1,1,1,1,64]{4,3,2,1,0} add(dynamic-slice.2, reshape.28)
|
||||
select = f32[1,1,1,1,64]{4,3,2,1,0} select(broadcast.2, add.1, dynamic-slice.2)
|
||||
reshape.29 = s32[] reshape(slice.18)
|
||||
slice.29 = s32[1]{0} slice(reshape.4), slice={[1:2]}
|
||||
reshape.15 = s32[] reshape(slice.29)
|
||||
slice.30 = s32[1]{0} slice(reshape.4), slice={[2:3]}
|
||||
reshape.16 = s32[] reshape(slice.30)
|
||||
slice.31 = s32[1]{0} slice(reshape.4), slice={[3:4]}
|
||||
reshape.17 = s32[] reshape(slice.31)
|
||||
reshape.18 = s32[] reshape(constant.6)
|
||||
dynamic-update-slice = f32[8,3,96,1,64]{4,3,2,1,0} dynamic-update-slice(get-tuple-element.2, select, reshape.29, reshape.15, reshape.16, reshape.17, reshape.18)
|
||||
ROOT tuple.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(add, dynamic-update-slice, get-tuple-element.3, get-tuple-element.4)
|
||||
}
|
||||
|
||||
while_cond {
|
||||
param.0 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0)
|
||||
get-tuple-element = s32[] get-tuple-element(param.0), index=0
|
||||
constant.2 = s32[] constant(16)
|
||||
ROOT compare = pred[] compare(get-tuple-element, constant.2), direction=LT
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
constant = s32[] constant(0)
|
||||
z = f32[] constant(0)
|
||||
b = f32[8,3,96,1,64]{4,3,2,1,0} broadcast(z), dimensions={}
|
||||
i = s32[8,2,4]{2,1,0} parameter(0)
|
||||
reshape = s32[16,4]{1,0} reshape(i)
|
||||
u = f32[8,2,64]{2,1,0} parameter(1)
|
||||
reshape.1 = f32[16,64]{1,0} reshape(u)
|
||||
tuple = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(constant, b, reshape, reshape.1)
|
||||
while = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) while(tuple), condition=while_cond, body=while_body
|
||||
ROOT get-tuple-element.5 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(while), index=1
|
||||
}
|
||||
)";
|
||||
|
||||
Literal updates =
|
||||
Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {8, 2, 64}));
|
||||
updates.PopulateWithValue(1.0f);
|
||||
|
||||
Literal indices =
|
||||
Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {8, 2, 4}));
|
||||
indices
|
||||
.Populate<int>([&](absl::Span<const int64> indices) -> int {
|
||||
auto i = indices[2] + indices[1] * 4 + indices[0] * 2 * 4;
|
||||
switch (indices[2]) {
|
||||
case 0:
|
||||
return i % 8;
|
||||
case 1:
|
||||
return i % 3;
|
||||
case 2:
|
||||
return i % 96;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
})
|
||||
.IgnoreError();
|
||||
|
||||
ErrorSpec no_error(0, 0);
|
||||
EXPECT_TRUE(
|
||||
RunAndCompare(ParseAndReturnVerifiedModule(kModuleStr).ValueOrDie(),
|
||||
{&indices, &updates}, no_error));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace xla
|
@ -23,6 +23,37 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace llvm_ir {
|
||||
|
||||
bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr) {
|
||||
// Today we can't emit a dynamic-update-slice if the DUS node is parallized;
|
||||
// the emitter will not emit correct code. It's possible to change this, but
|
||||
// then ParallelTaskAssigner would have to somehow know whether a node *will*
|
||||
// be emitted as an in-place DUS, and it can't, because it doesn't have a
|
||||
// buffer assignment when it runs.
|
||||
if (!instr->outer_dimension_partitions().empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Until we know the final buffer assignment, any unfused dynamic-update-slice
|
||||
// might be implementable as an in-place DUS.
|
||||
if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// A fusion may be implementable as an in-place dynamic update slice if
|
||||
// - it's a loop fusion,
|
||||
// - dynamic-update-slice is the root of the fusion, and
|
||||
// - operand 0 of the dynamic-update-slice is a parameter to the fusion
|
||||
// (ignoring any get-tuple-element operations in the way).
|
||||
if (instr->IsLoopFusion()) {
|
||||
const HloInstruction* fused_root = instr->fused_expression_root();
|
||||
return fused_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
|
||||
fused_root->operand(0)->LatestNonGteAncestor()->opcode() ==
|
||||
HloOpcode::kParameter;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
|
||||
const BufferAssignment& assignment) {
|
||||
CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
|
||||
@ -32,6 +63,29 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
|
||||
assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
|
||||
}
|
||||
|
||||
bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
                                           const BufferAssignment& assignment) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  // Cheap structural pre-check; rules out non-loop fusions, fusions whose
  // root is not a DUS, and parallelized nodes.
  if (!MayBeImplementedAsInPlaceDynamicUpdateSlice(fusion)) {
    return false;
  }

  // Trace operand 0 of the root dynamic-update-slice back through any
  // get-tuple-element ops to the fusion parameter it reads, then look up the
  // fusion operand feeding that parameter.
  HloInstruction* dus = fusion->fused_expression_root();
  HloInstruction* param;
  ShapeIndex param_index;
  std::tie(param, param_index) =
      dus->mutable_operand(0)->LatestNonGteAncestorAndIndex();
  // MayBeImplementedAsInPlaceDynamicUpdateSlice guarantees this is a
  // parameter.
  CHECK_EQ(param->opcode(), HloOpcode::kParameter);
  auto* fusion_input = fusion->operand(param->parameter_number());

  // In-place emission is legal only if the fusion's output buffer is the same
  // slice as the buffer backing that input.
  return assignment.HasAllocationAt(fusion_input, param_index) &&
         assignment.HasAllocationAt(fusion, {}) &&
         assignment.SharesSliceAtIndex(fusion, {}, fusion_input, param_index);
}
|
||||
|
||||
// Shared implementation of EmitDynamicUpdateSliceInPlace and
|
||||
// EmitFusedDynamicUpdateSliceInPlace.
|
||||
//
|
||||
|
@ -30,6 +30,22 @@ namespace llvm_ir {
|
||||
using GeneratorForOperandIrArrays =
|
||||
std::function<std::vector<llvm_ir::IrArray>()>;
|
||||
|
||||
// Determines whether the given instruction might be implemented as an
|
||||
// in-place dynamic-update-slice after we have a buffer assignment.
|
||||
//
|
||||
// If this returns false, then CanUpdateDynamicSliceInPlace and
|
||||
// CanEmitFusedDynamicUpdateSliceInPlace will also return false.
|
||||
//
|
||||
// This is useful if you want to check whether an instruction might be an
|
||||
// in-place DUS during an HLO pass, at which point you don't have a buffer
|
||||
// assignment.
|
||||
//
|
||||
// Note that simplifications to the HLO graph might change this function from
|
||||
// returning false to returning true. Specifically, simplifying the contents of
|
||||
// fusion nodes might cause a false->true transition. In general this isn't a
|
||||
// problem by the time you're calling this function, but beware.
|
||||
bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr);
|
||||
|
||||
// Checks if we can emit code for the given DynamicUpdateSlice node that updates
|
||||
// its input in place. Returns true if the dynamic-update-slice's
|
||||
// array-to-be-updated and output share the same BufferAllocation::Slice.
|
||||
@ -40,28 +56,8 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
|
||||
|
||||
// Checks if the given fusion node is amenable to being implemented by
|
||||
// EmitFusedDynamicUpdateSliceInPlace.
|
||||
inline bool CanEmitFusedDynamicUpdateSliceInPlace(
|
||||
HloInstruction* fusion, const BufferAssignment& assignment) {
|
||||
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
|
||||
HloInstruction* fused_root = fusion->fused_expression_root();
|
||||
if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice ||
|
||||
!fusion->IsLoopFusion()) {
|
||||
return false;
|
||||
}
|
||||
// Walk DynamicUpdateSlice operand(0) to fused parameter and get its
|
||||
// associated operand. See if it shares an allocation with this operand.
|
||||
HloInstruction* fusion_operand;
|
||||
ShapeIndex index;
|
||||
std::tie(fusion_operand, index) =
|
||||
fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
|
||||
if (fusion_operand->opcode() != HloOpcode::kParameter) {
|
||||
return false;
|
||||
}
|
||||
auto* operand = fusion->operand(fusion_operand->parameter_number());
|
||||
return assignment.HasAllocationAt(operand, index) &&
|
||||
assignment.HasAllocationAt(fusion, {}) &&
|
||||
assignment.SharesSliceAtIndex(fusion, {}, operand, index);
|
||||
}
|
||||
bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
|
||||
const BufferAssignment& assignment);
|
||||
|
||||
// Emits IR for running the given dynamic-update-slice op in-place -- that is,
|
||||
// where the input and output buffers share the same slice, so we can simply
|
||||
|
Loading…
Reference in New Issue
Block a user