[XLA:CPU] Don't parallelize in-place dynamic-update-slice.
Suppose we have out = dynamic-update-slice(in, update, indices...). If `in` and `out` are different memory locations, this is basically a memcpy, with most of the data coming from `in` and part coming from `update`. However if `in` and `out` are the same memory location, there's a faster implementation: Simply write the values from `update` over `in`/`out`. We call this an in-place dynamic-update-slice (DUS). In-place DUS is also possible for loop fusions which have a dynamic-update-slice as the root. The criterion is basically the same: The `in` operand to the dynamic-update-slice must be a parameter to the fusion, and it must share a buffer with the `out` of the DUS. Given a DUS op, we don't know whether we can implement it using the in-place algorithm until after buffer assignment. And buffer assignment necessarily occurs after all HLO transformations; it's illegal to change the graph after doing buffer assignment. So although HLO passes can sometimes look at the graph and say "this HLO can't be an in-place DUS", HLO passes *can't* say "this HLO will definitely be an in-place DUS". The job of ParallelTaskAssignment is to shard HLOs up across multiple CPU cores. To do this, it needs to know how many elements a particular HLO writes. Note that in-place and out-of-place DUS ops write different numbers of elements! This means that if we have an HLO which might be implemented as an in-place DUS, we can't shard it. Sharding an in-place DUS yields incorrect results, maybe due to out-of-bounds reads/writes. PiperOrigin-RevId: 248950558
This commit is contained in:
parent
810b454169
commit
abdee716ce
@ -1,4 +1,4 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
@ -402,6 +402,27 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
xla_test(
|
||||||
|
name = "dynamic_update_slice_test",
|
||||||
|
srcs = ["dynamic_update_slice_test.cc"],
|
||||||
|
backends = [
|
||||||
|
"cpu",
|
||||||
|
"gpu",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":hlo_parser",
|
||||||
|
"//tensorflow/compiler/xla:execution_options_util",
|
||||||
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
"//tensorflow/compiler/xla:test",
|
||||||
|
"//tensorflow/compiler/xla/service/cpu:cpu_executable",
|
||||||
|
"//tensorflow/compiler/xla/service/cpu:parallel_task_assignment",
|
||||||
|
"//tensorflow/compiler/xla/service/cpu:target_machine_features",
|
||||||
|
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "dfs_hlo_visitor_with_default_test",
|
name = "dfs_hlo_visitor_with_default_test",
|
||||||
srcs = ["dfs_hlo_visitor_with_default_test.cc"],
|
srcs = ["dfs_hlo_visitor_with_default_test.cc"],
|
||||||
|
@ -905,6 +905,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||||
|
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
@ -135,6 +136,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
|
|||||||
// *) Emit custom loops (kSelectAndScatter).
|
// *) Emit custom loops (kSelectAndScatter).
|
||||||
// *) Operations that are not thread safe (like infeed and rng).
|
// *) Operations that are not thread safe (like infeed and rng).
|
||||||
// *) Tuple-shaped.
|
// *) Tuple-shaped.
|
||||||
|
// *) Operations that might be implemented as an in-place
|
||||||
|
// dynamic-update-slice, because we can't know how many output elements
|
||||||
|
// they will write (out-of-place will touch the whole output buffer, while
|
||||||
|
// in-place will only touch the updated elements).
|
||||||
// TODO(b/27458679) Parallelize instructions which are skipped here.
|
// TODO(b/27458679) Parallelize instructions which are skipped here.
|
||||||
auto opcode = instruction->opcode();
|
auto opcode = instruction->opcode();
|
||||||
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
|
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
|
||||||
@ -148,6 +153,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
|
|||||||
PotentiallyImplementedAsEigenConvolution(*instruction,
|
PotentiallyImplementedAsEigenConvolution(*instruction,
|
||||||
target_machine_features_)) ||
|
target_machine_features_)) ||
|
||||||
(opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
|
(opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
|
||||||
|
llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
|
||||||
instruction->shape().IsTuple()) {
|
instruction->shape().IsTuple()) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -125,5 +125,50 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
|
|||||||
EXPECT_FALSE(changed);
|
EXPECT_FALSE(changed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
|
||||||
|
// A dynamic-update-slice within a while loop. This construction is an easy
|
||||||
|
// way to make a DUS which can be run "in-place" (i.e. the input and output
|
||||||
|
// are the same buffer, and running the DUS only writes to the updated
|
||||||
|
// elements).
|
||||||
|
const string hlo_string = R"(
|
||||||
|
HloModule test
|
||||||
|
|
||||||
|
body {
|
||||||
|
zero = s32[] constant(0)
|
||||||
|
one = s32[] constant(1)
|
||||||
|
ten = s32[] constant(10)
|
||||||
|
loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
|
||||||
|
i = s32[] get-tuple-element(loop_carry), index=0
|
||||||
|
i_plus_ten = s32[] add(i, ten)
|
||||||
|
update = u32[1,100] get-tuple-element(loop_carry), index=1
|
||||||
|
data = u32[10000,100] get-tuple-element(loop_carry), index=2
|
||||||
|
new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero)
|
||||||
|
new_i = s32[] add(i, one)
|
||||||
|
ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
cond {
|
||||||
|
loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
|
||||||
|
two = s32[] constant(2)
|
||||||
|
i = s32[] get-tuple-element(loop_carry), index=0
|
||||||
|
ROOT less-than = pred[] compare(i, two), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY test {
|
||||||
|
zero = s32[] constant(0)
|
||||||
|
initial_i = s32[] parameter(0)
|
||||||
|
update = u32[1,100] parameter(1)
|
||||||
|
data = u32[10000,100] parameter(2)
|
||||||
|
tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data)
|
||||||
|
ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get()));
|
||||||
|
EXPECT_FALSE(changed);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
197
tensorflow/compiler/xla/service/dynamic_update_slice_test.cc
Normal file
197
tensorflow/compiler/xla/service/dynamic_update_slice_test.cc
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class DynamicUpdateSliceTest : public HloTestBase {};
|
||||||
|
|
||||||
|
XLA_TEST_F(DynamicUpdateSliceTest, ShardedInPlaceDUS) {
|
||||||
|
// A dynamic-update-slice within a while loop. This construction is an easy
|
||||||
|
// way to make a DUS which can be run "in-place" (i.e. the input and output
|
||||||
|
// are the same buffer, and running the DUS only writes to the updated
|
||||||
|
// elements).
|
||||||
|
const char kModuleStr[] = R"(
|
||||||
|
HloModule test
|
||||||
|
|
||||||
|
body {
|
||||||
|
zero = s32[] constant(0)
|
||||||
|
one = s32[] constant(1)
|
||||||
|
ten = s32[] constant(10)
|
||||||
|
loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
|
||||||
|
i = s32[] get-tuple-element(loop_carry), index=0
|
||||||
|
i_plus_ten = s32[] add(i, ten)
|
||||||
|
update = u32[1,100] get-tuple-element(loop_carry), index=1
|
||||||
|
data = u32[10000,100] get-tuple-element(loop_carry), index=2
|
||||||
|
new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero)
|
||||||
|
new_i = s32[] add(i, one)
|
||||||
|
ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
cond {
|
||||||
|
loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
|
||||||
|
two = s32[] constant(2)
|
||||||
|
i = s32[] get-tuple-element(loop_carry), index=0
|
||||||
|
ROOT less-than = pred[] compare(i, two), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY test {
|
||||||
|
zero = s32[] constant(0)
|
||||||
|
initial_i = s32[] parameter(0)
|
||||||
|
update = u32[1,100] parameter(1)
|
||||||
|
data = u32[10000,100] parameter(2)
|
||||||
|
tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data)
|
||||||
|
ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, MakeFakeArguments(module.get()));
|
||||||
|
fake_arguments[0] = LiteralUtil::CreateR0<int32>(0);
|
||||||
|
|
||||||
|
std::vector<Literal*> fake_argument_ptrs;
|
||||||
|
absl::c_transform(
|
||||||
|
fake_arguments, std::back_inserter(fake_argument_ptrs),
|
||||||
|
[](const Literal& literal) { return &const_cast<Literal&>(literal); });
|
||||||
|
|
||||||
|
ErrorSpec no_error(0, 0);
|
||||||
|
EXPECT_TRUE(RunAndCompare(std::move(module), fake_argument_ptrs, no_error));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regression test for a dynamic-update-slice involved in the expansion of a
|
||||||
|
// kScatter op. Apologies for the large testcase; this proved difficult to
|
||||||
|
// reduce. The bug we're checking for occurs when the dynamic-update-slice is
|
||||||
|
// run in place but is sharded across cores by ParallelTaskAssigner.
|
||||||
|
XLA_TEST_F(DynamicUpdateSliceTest, ExpandedScatter) {
|
||||||
|
const char kModuleStr[] = R"(
|
||||||
|
HloModule TensorFlowScatter
|
||||||
|
|
||||||
|
and.reduce_sub_computation {
|
||||||
|
lhs = pred[] parameter(0)
|
||||||
|
rhs = pred[] parameter(1)
|
||||||
|
ROOT and = pred[] and(lhs, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
while_body {
|
||||||
|
param.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0)
|
||||||
|
get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
|
||||||
|
constant.4 = s32[] constant(1)
|
||||||
|
add = s32[] add(get-tuple-element.1, constant.4)
|
||||||
|
get-tuple-element.2 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(param.1), index=1
|
||||||
|
constant.8 = s32[] constant(0)
|
||||||
|
broadcast.1 = s32[5]{0} broadcast(constant.8), dimensions={}
|
||||||
|
get-tuple-element.3 = s32[16,4]{1,0} get-tuple-element(param.1), index=2
|
||||||
|
constant.5 = s32[] constant(0)
|
||||||
|
dynamic-slice = s32[1,4]{1,0} dynamic-slice(get-tuple-element.3, get-tuple-element.1, constant.5), dynamic_slice_sizes={1,4}
|
||||||
|
slice.18 = s32[1,1]{1,0} slice(dynamic-slice), slice={[0:1], [0:1]}
|
||||||
|
reshape.23 = s32[1]{0} reshape(slice.18)
|
||||||
|
reshape.4 = s32[4]{0} reshape(dynamic-slice)
|
||||||
|
slice.19 = s32[3]{0} slice(reshape.4), slice={[1:4]}
|
||||||
|
constant.6 = s32[1]{0} constant({0})
|
||||||
|
concatenate.1 = s32[5]{0} concatenate(reshape.23, slice.19, constant.6), dimensions={0}
|
||||||
|
compare.1 = pred[5]{0} compare(broadcast.1, concatenate.1), direction=LE
|
||||||
|
constant.9 = s32[5]{0} constant({7, 2, 95, 0, 0})
|
||||||
|
compare.2 = pred[5]{0} compare(constant.9, concatenate.1), direction=GE
|
||||||
|
and.1 = pred[5]{0} and(compare.1, compare.2)
|
||||||
|
constant.10 = pred[] constant(true)
|
||||||
|
reduce = pred[] reduce(and.1, constant.10), dimensions={0}, to_apply=and.reduce_sub_computation
|
||||||
|
broadcast.2 = pred[1,1,1,1,64]{4,3,2,1,0} broadcast(reduce), dimensions={}
|
||||||
|
reshape.24 = s32[] reshape(slice.18)
|
||||||
|
slice.26 = s32[1]{0} slice(reshape.4), slice={[1:2]}
|
||||||
|
reshape.10 = s32[] reshape(slice.26)
|
||||||
|
slice.27 = s32[1]{0} slice(reshape.4), slice={[2:3]}
|
||||||
|
reshape.11 = s32[] reshape(slice.27)
|
||||||
|
slice.28 = s32[1]{0} slice(reshape.4), slice={[3:4]}
|
||||||
|
reshape.12 = s32[] reshape(slice.28)
|
||||||
|
reshape.13 = s32[] reshape(constant.6)
|
||||||
|
dynamic-slice.2 = f32[1,1,1,1,64]{4,3,2,1,0} dynamic-slice(get-tuple-element.2, reshape.24, reshape.10, reshape.11, reshape.12, reshape.13), dynamic_slice_sizes={1,1,1,1,64}
|
||||||
|
get-tuple-element.4 = f32[16,64]{1,0} get-tuple-element(param.1), index=3
|
||||||
|
constant.7 = s32[] constant(0)
|
||||||
|
dynamic-slice.1 = f32[1,64]{1,0} dynamic-slice(get-tuple-element.4, get-tuple-element.1, constant.7), dynamic_slice_sizes={1,64}
|
||||||
|
reshape.28 = f32[1,1,1,1,64]{4,3,2,1,0} reshape(dynamic-slice.1)
|
||||||
|
add.1 = f32[1,1,1,1,64]{4,3,2,1,0} add(dynamic-slice.2, reshape.28)
|
||||||
|
select = f32[1,1,1,1,64]{4,3,2,1,0} select(broadcast.2, add.1, dynamic-slice.2)
|
||||||
|
reshape.29 = s32[] reshape(slice.18)
|
||||||
|
slice.29 = s32[1]{0} slice(reshape.4), slice={[1:2]}
|
||||||
|
reshape.15 = s32[] reshape(slice.29)
|
||||||
|
slice.30 = s32[1]{0} slice(reshape.4), slice={[2:3]}
|
||||||
|
reshape.16 = s32[] reshape(slice.30)
|
||||||
|
slice.31 = s32[1]{0} slice(reshape.4), slice={[3:4]}
|
||||||
|
reshape.17 = s32[] reshape(slice.31)
|
||||||
|
reshape.18 = s32[] reshape(constant.6)
|
||||||
|
dynamic-update-slice = f32[8,3,96,1,64]{4,3,2,1,0} dynamic-update-slice(get-tuple-element.2, select, reshape.29, reshape.15, reshape.16, reshape.17, reshape.18)
|
||||||
|
ROOT tuple.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(add, dynamic-update-slice, get-tuple-element.3, get-tuple-element.4)
|
||||||
|
}
|
||||||
|
|
||||||
|
while_cond {
|
||||||
|
param.0 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0)
|
||||||
|
get-tuple-element = s32[] get-tuple-element(param.0), index=0
|
||||||
|
constant.2 = s32[] constant(16)
|
||||||
|
ROOT compare = pred[] compare(get-tuple-element, constant.2), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
constant = s32[] constant(0)
|
||||||
|
z = f32[] constant(0)
|
||||||
|
b = f32[8,3,96,1,64]{4,3,2,1,0} broadcast(z), dimensions={}
|
||||||
|
i = s32[8,2,4]{2,1,0} parameter(0)
|
||||||
|
reshape = s32[16,4]{1,0} reshape(i)
|
||||||
|
u = f32[8,2,64]{2,1,0} parameter(1)
|
||||||
|
reshape.1 = f32[16,64]{1,0} reshape(u)
|
||||||
|
tuple = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(constant, b, reshape, reshape.1)
|
||||||
|
while = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) while(tuple), condition=while_cond, body=while_body
|
||||||
|
ROOT get-tuple-element.5 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(while), index=1
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
Literal updates =
|
||||||
|
Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {8, 2, 64}));
|
||||||
|
updates.PopulateWithValue(1.0f);
|
||||||
|
|
||||||
|
Literal indices =
|
||||||
|
Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {8, 2, 4}));
|
||||||
|
indices
|
||||||
|
.Populate<int>([&](absl::Span<const int64> indices) -> int {
|
||||||
|
auto i = indices[2] + indices[1] * 4 + indices[0] * 2 * 4;
|
||||||
|
switch (indices[2]) {
|
||||||
|
case 0:
|
||||||
|
return i % 8;
|
||||||
|
case 1:
|
||||||
|
return i % 3;
|
||||||
|
case 2:
|
||||||
|
return i % 96;
|
||||||
|
default:
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.IgnoreError();
|
||||||
|
|
||||||
|
ErrorSpec no_error(0, 0);
|
||||||
|
EXPECT_TRUE(
|
||||||
|
RunAndCompare(ParseAndReturnVerifiedModule(kModuleStr).ValueOrDie(),
|
||||||
|
{&indices, &updates}, no_error));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
} // namespace xla
|
@ -23,6 +23,37 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace llvm_ir {
|
namespace llvm_ir {
|
||||||
|
|
||||||
|
bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr) {
|
||||||
|
// Today we can't emit a dynamic-update-slice if the DUS node is parallelized;
|
||||||
|
// the emitter will not emit correct code. It's possible to change this, but
|
||||||
|
// then ParallelTaskAssigner would have to somehow know whether a node *will*
|
||||||
|
// be emitted as an in-place DUS, and it can't, because it doesn't have a
|
||||||
|
// buffer assignment when it runs.
|
||||||
|
if (!instr->outer_dimension_partitions().empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Until we know the final buffer assignment, any unfused dynamic-update-slice
|
||||||
|
// might be implementable as an in-place DUS.
|
||||||
|
if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A fusion may be implementable as an in-place dynamic update slice if
|
||||||
|
// - it's a loop fusion,
|
||||||
|
// - dynamic-update-slice is the root of the fusion, and
|
||||||
|
// - operand 0 of the dynamic-update-slice is a parameter to the fusion
|
||||||
|
// (ignoring any get-tuple-element operations in the way).
|
||||||
|
if (instr->IsLoopFusion()) {
|
||||||
|
const HloInstruction* fused_root = instr->fused_expression_root();
|
||||||
|
return fused_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
|
||||||
|
fused_root->operand(0)->LatestNonGteAncestor()->opcode() ==
|
||||||
|
HloOpcode::kParameter;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
|
bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
|
||||||
const BufferAssignment& assignment) {
|
const BufferAssignment& assignment) {
|
||||||
CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
|
CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
|
||||||
@ -32,6 +63,29 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
|
|||||||
assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
|
assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
|
||||||
|
const BufferAssignment& assignment) {
|
||||||
|
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
|
||||||
|
if (!MayBeImplementedAsInPlaceDynamicUpdateSlice(fusion)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walk DynamicUpdateSlice operand(0) to fused parameter and get its
|
||||||
|
// associated operand. See if it shares an allocation with this operand.
|
||||||
|
HloInstruction* fused_root = fusion->fused_expression_root();
|
||||||
|
HloInstruction* fusion_operand;
|
||||||
|
ShapeIndex index;
|
||||||
|
std::tie(fusion_operand, index) =
|
||||||
|
fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
|
||||||
|
// MayBeImplementedAsInPlaceDynamicUpdateSlice should have ensured that
|
||||||
|
// fusion_operand is a parameter.
|
||||||
|
CHECK_EQ(fusion_operand->opcode(), HloOpcode::kParameter);
|
||||||
|
auto* operand = fusion->operand(fusion_operand->parameter_number());
|
||||||
|
return assignment.HasAllocationAt(operand, index) &&
|
||||||
|
assignment.HasAllocationAt(fusion, {}) &&
|
||||||
|
assignment.SharesSliceAtIndex(fusion, {}, operand, index);
|
||||||
|
}
|
||||||
|
|
||||||
// Shared implementation of EmitDynamicUpdateSliceInPlace and
|
// Shared implementation of EmitDynamicUpdateSliceInPlace and
|
||||||
// EmitFusedDynamicUpdateSliceInPlace.
|
// EmitFusedDynamicUpdateSliceInPlace.
|
||||||
//
|
//
|
||||||
|
@ -30,6 +30,22 @@ namespace llvm_ir {
|
|||||||
using GeneratorForOperandIrArrays =
|
using GeneratorForOperandIrArrays =
|
||||||
std::function<std::vector<llvm_ir::IrArray>()>;
|
std::function<std::vector<llvm_ir::IrArray>()>;
|
||||||
|
|
||||||
|
// Determines whether the given instruction might be implemented as an
|
||||||
|
// in-place dynamic-update-slice after we have a buffer assignment.
|
||||||
|
//
|
||||||
|
// If this returns false, then CanUpdateDynamicSliceInPlace and
|
||||||
|
// CanEmitFusedDynamicUpdateSliceInPlace will also return false.
|
||||||
|
//
|
||||||
|
// This is useful if you want to check whether an instruction might be an
|
||||||
|
// in-place DUS during an HLO pass, at which point you don't have a buffer
|
||||||
|
// assignment.
|
||||||
|
//
|
||||||
|
// Note that simplifications to the HLO graph might change this function from
|
||||||
|
// returning false to returning true. Specifically, simplifying the contents of
|
||||||
|
// fusion nodes might cause a false->true transition. In general this isn't a
|
||||||
|
// problem by the time you're calling this function, but beware.
|
||||||
|
bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr);
|
||||||
|
|
||||||
// Checks if we can emit code for the given DynamicUpdateSlice node that updates
|
// Checks if we can emit code for the given DynamicUpdateSlice node that updates
|
||||||
// its input in place. Returns true if the dynamic-update-slice's
|
// its input in place. Returns true if the dynamic-update-slice's
|
||||||
// array-to-be-updated and output share the same BufferAllocation::Slice.
|
// array-to-be-updated and output share the same BufferAllocation::Slice.
|
||||||
@ -40,28 +56,8 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
|
|||||||
|
|
||||||
// Checks if the given fusion node is amenable to being implemented by
|
// Checks if the given fusion node is amenable to being implemented by
|
||||||
// EmitFusedDynamicUpdateSliceInPlace.
|
// EmitFusedDynamicUpdateSliceInPlace.
|
||||||
inline bool CanEmitFusedDynamicUpdateSliceInPlace(
|
bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
|
||||||
HloInstruction* fusion, const BufferAssignment& assignment) {
|
const BufferAssignment& assignment);
|
||||||
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
|
|
||||||
HloInstruction* fused_root = fusion->fused_expression_root();
|
|
||||||
if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice ||
|
|
||||||
!fusion->IsLoopFusion()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Walk DynamicUpdateSlice operand(0) to fused parameter and get its
|
|
||||||
// associated operand. See if it shares an allocation with this operand.
|
|
||||||
HloInstruction* fusion_operand;
|
|
||||||
ShapeIndex index;
|
|
||||||
std::tie(fusion_operand, index) =
|
|
||||||
fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
|
|
||||||
if (fusion_operand->opcode() != HloOpcode::kParameter) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto* operand = fusion->operand(fusion_operand->parameter_number());
|
|
||||||
return assignment.HasAllocationAt(operand, index) &&
|
|
||||||
assignment.HasAllocationAt(fusion, {}) &&
|
|
||||||
assignment.SharesSliceAtIndex(fusion, {}, operand, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emits IR for running the given dynamic-update-slice op in-place -- that is,
|
// Emits IR for running the given dynamic-update-slice op in-place -- that is,
|
||||||
// where the input and output buffers share the same slice, so we can simply
|
// where the input and output buffers share the same slice, so we can simply
|
||||||
|
Loading…
Reference in New Issue
Block a user