[XLA:CPU] Don't parallelize in-place dynamic-update-slice.

Suppose we have

  out = dynamic-update-slice(in, update, indices...).

If `in` and `out` are different memory locations, this is basically a memcpy,
with most of the data coming from `in` and part coming from `update`.

However if `in` and `out` are the same memory location, there's a faster
implementation: Simply write the values from `update` over `in`/`out`.  We call
this an in-place dynamic-update-slice (DUS).

In-place DUS is also possible for loop fusions which have a
dynamic-update-slice as the root.  The criterion is basically the same: The
`in` operand to the dynamic-update-slice must be a parameter to the fusion, and
it must share a buffer with the `out` of the DUS.

Given a DUS op, we don't know whether we can implement it using the in-place
algorithm until after buffer assignment.  And buffer assignment necessarily
occurs after all HLO transformations; it's illegal to change the graph after
doing buffer assignment.  So although HLO passes can sometimes look at the
graph and say "this HLO can't be an in-place DUS", HLO passes *can't* say "this
HLO will definitely be an in-place DUS".

The job of ParallelTaskAssignment is to shard HLOs up across multiple CPU
cores.  To do this, it needs to know how many elements a particular HLO writes.
Note that in-place and out-of-place DUS ops write different numbers of
elements!  This means that if we have an HLO which might be implemented as an
in-place DUS, we can't shard it.

Sharding an in-place DUS yields incorrect results, maybe due to out-of-bounds
reads/writes.

PiperOrigin-RevId: 248950558
This commit is contained in:
Justin Lebar 2019-05-19 11:37:30 -07:00 committed by TensorFlower Gardener
parent 810b454169
commit abdee716ce
8 changed files with 343 additions and 23 deletions

View File

@@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -402,6 +402,27 @@ tf_cc_test(
],
)
xla_test(
name = "dynamic_update_slice_test",
srcs = ["dynamic_update_slice_test.cc"],
backends = [
"cpu",
"gpu",
],
deps = [
":hlo_parser",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service/cpu:cpu_executable",
"//tensorflow/compiler/xla/service/cpu:parallel_task_assignment",
"//tensorflow/compiler/xla/service/cpu:target_machine_features",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
tf_cc_test(
name = "dfs_hlo_visitor_with_default_test",
srcs = ["dfs_hlo_visitor_with_default_test.cc"],

View File

@@ -905,6 +905,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
namespace xla {
namespace cpu {
@@ -135,6 +136,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
// *) Emit custom loops (kSelectAndScatter).
// *) Operations that are not thread safe (like infeed and rng).
// *) Tuple-shaped.
// *) Operations that might be implemented as an in-place
// dynamic-update-slice, because we can't know how many output elements
// they will write (out-of-place will touch the whole output buffer, while
// in-place will only touch the updated elements).
// TODO(b/27458679) Parallelize instructions which are skipped here.
auto opcode = instruction->opcode();
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
@@ -148,6 +153,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
PotentiallyImplementedAsEigenConvolution(*instruction,
target_machine_features_)) ||
(opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
instruction->shape().IsTuple()) {
return 1;
}

View File

@@ -125,5 +125,50 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
EXPECT_FALSE(changed);
}
// Verifies that ParallelTaskAssigner refuses to shard a dynamic-update-slice
// that might later be emitted in place (the in-place decision is only made
// after buffer assignment, so the pass must be conservative here).
TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
// A dynamic-update-slice within a while loop. This construction is an easy
// way to make a DUS which can be run "in-place" (i.e. the input and output
// are the same buffer, and running the DUS only writes to the updated
// elements).
const string hlo_string = R"(
HloModule test
body {
zero = s32[] constant(0)
one = s32[] constant(1)
ten = s32[] constant(10)
loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
i = s32[] get-tuple-element(loop_carry), index=0
i_plus_ten = s32[] add(i, ten)
update = u32[1,100] get-tuple-element(loop_carry), index=1
data = u32[10000,100] get-tuple-element(loop_carry), index=2
new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero)
new_i = s32[] add(i, one)
ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data)
}
cond {
loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
two = s32[] constant(2)
i = s32[] get-tuple-element(loop_carry), index=0
ROOT less-than = pred[] compare(i, two), direction=LT
}
ENTRY test {
zero = s32[] constant(0)
initial_i = s32[] parameter(0)
update = u32[1,100] parameter(1)
data = u32[10000,100] parameter(2)
tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data)
ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
ParseAndReturnVerifiedModule(hlo_string));
// The assigner must leave the module untouched: `changed == false` means no
// instruction (in particular the DUS) was sharded across cores.
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get()));
EXPECT_FALSE(changed);
}
} // namespace
} // namespace xla

View File

@@ -0,0 +1,197 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
// Fixture for end-to-end dynamic-update-slice tests; inherits HLO-text
// parsing and run-and-compare helpers from HloTestBase.
class DynamicUpdateSliceTest : public HloTestBase {};
// End-to-end check that an in-place-capable DUS inside a while loop produces
// correct results when executed on the backend.
XLA_TEST_F(DynamicUpdateSliceTest, ShardedInPlaceDUS) {
// A dynamic-update-slice within a while loop. This construction is an easy
// way to make a DUS which can be run "in-place" (i.e. the input and output
// are the same buffer, and running the DUS only writes to the updated
// elements).
const char kModuleStr[] = R"(
HloModule test
body {
zero = s32[] constant(0)
one = s32[] constant(1)
ten = s32[] constant(10)
loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
i = s32[] get-tuple-element(loop_carry), index=0
i_plus_ten = s32[] add(i, ten)
update = u32[1,100] get-tuple-element(loop_carry), index=1
data = u32[10000,100] get-tuple-element(loop_carry), index=2
new_data = u32[10000,100] dynamic-update-slice(data, update, i_plus_ten, zero)
new_i = s32[] add(i, one)
ROOT tuple = (s32[], u32[1,100], u32[10000,100]) tuple(new_i, update, new_data)
}
cond {
loop_carry = (s32[], u32[1,100], u32[10000,100]) parameter(0)
two = s32[] constant(2)
i = s32[] get-tuple-element(loop_carry), index=0
ROOT less-than = pred[] compare(i, two), direction=LT
}
ENTRY test {
zero = s32[] constant(0)
initial_i = s32[] parameter(0)
update = u32[1,100] parameter(1)
data = u32[10000,100] parameter(2)
tuple = (s32[], u32[1,100], u32[10000,100]) tuple(initial_i, update, data)
ROOT while = (s32[], u32[1,100], u32[10000,100]) while(tuple), condition=cond, body=body
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, MakeFakeArguments(module.get()));
// Pin the loop counter (parameter 0) to 0 so the while loop's `i < 2`
// condition holds and the body actually runs.
fake_arguments[0] = LiteralUtil::CreateR0<int32>(0);
// RunAndCompare takes non-const Literal*; the const_cast only strips the
// const added by the lambda's const-ref parameter.
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
[](const Literal& literal) { return &const_cast<Literal&>(literal); });
// Zero tolerance: the backend result must match the reference exactly.
ErrorSpec no_error(0, 0);
EXPECT_TRUE(RunAndCompare(std::move(module), fake_argument_ptrs, no_error));
}
// Regression test for a dynamic-update-slice involved in the expansion of a
// kScatter op. Apologies for the large testcase, this proved difficult to
// reduce. The bug we're checking for occurs when the dynamic-update-slice is
// run in place but is sharded across cores by ParallelTaskAssigner.
XLA_TEST_F(DynamicUpdateSliceTest, ExpandedScatter) {
const char kModuleStr[] = R"(
HloModule TensorFlowScatter

and.reduce_sub_computation {
lhs = pred[] parameter(0)
rhs = pred[] parameter(1)
ROOT and = pred[] and(lhs, rhs)
}
while_body {
param.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
constant.4 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.4)
get-tuple-element.2 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(param.1), index=1
constant.8 = s32[] constant(0)
broadcast.1 = s32[5]{0} broadcast(constant.8), dimensions={}
get-tuple-element.3 = s32[16,4]{1,0} get-tuple-element(param.1), index=2
constant.5 = s32[] constant(0)
dynamic-slice = s32[1,4]{1,0} dynamic-slice(get-tuple-element.3, get-tuple-element.1, constant.5), dynamic_slice_sizes={1,4}
slice.18 = s32[1,1]{1,0} slice(dynamic-slice), slice={[0:1], [0:1]}
reshape.23 = s32[1]{0} reshape(slice.18)
reshape.4 = s32[4]{0} reshape(dynamic-slice)
slice.19 = s32[3]{0} slice(reshape.4), slice={[1:4]}
constant.6 = s32[1]{0} constant({0})
concatenate.1 = s32[5]{0} concatenate(reshape.23, slice.19, constant.6), dimensions={0}
compare.1 = pred[5]{0} compare(broadcast.1, concatenate.1), direction=LE
constant.9 = s32[5]{0} constant({7, 2, 95, 0, 0})
compare.2 = pred[5]{0} compare(constant.9, concatenate.1), direction=GE
and.1 = pred[5]{0} and(compare.1, compare.2)
constant.10 = pred[] constant(true)
reduce = pred[] reduce(and.1, constant.10), dimensions={0}, to_apply=and.reduce_sub_computation
broadcast.2 = pred[1,1,1,1,64]{4,3,2,1,0} broadcast(reduce), dimensions={}
reshape.24 = s32[] reshape(slice.18)
slice.26 = s32[1]{0} slice(reshape.4), slice={[1:2]}
reshape.10 = s32[] reshape(slice.26)
slice.27 = s32[1]{0} slice(reshape.4), slice={[2:3]}
reshape.11 = s32[] reshape(slice.27)
slice.28 = s32[1]{0} slice(reshape.4), slice={[3:4]}
reshape.12 = s32[] reshape(slice.28)
reshape.13 = s32[] reshape(constant.6)
dynamic-slice.2 = f32[1,1,1,1,64]{4,3,2,1,0} dynamic-slice(get-tuple-element.2, reshape.24, reshape.10, reshape.11, reshape.12, reshape.13), dynamic_slice_sizes={1,1,1,1,64}
get-tuple-element.4 = f32[16,64]{1,0} get-tuple-element(param.1), index=3
constant.7 = s32[] constant(0)
dynamic-slice.1 = f32[1,64]{1,0} dynamic-slice(get-tuple-element.4, get-tuple-element.1, constant.7), dynamic_slice_sizes={1,64}
reshape.28 = f32[1,1,1,1,64]{4,3,2,1,0} reshape(dynamic-slice.1)
add.1 = f32[1,1,1,1,64]{4,3,2,1,0} add(dynamic-slice.2, reshape.28)
select = f32[1,1,1,1,64]{4,3,2,1,0} select(broadcast.2, add.1, dynamic-slice.2)
reshape.29 = s32[] reshape(slice.18)
slice.29 = s32[1]{0} slice(reshape.4), slice={[1:2]}
reshape.15 = s32[] reshape(slice.29)
slice.30 = s32[1]{0} slice(reshape.4), slice={[2:3]}
reshape.16 = s32[] reshape(slice.30)
slice.31 = s32[1]{0} slice(reshape.4), slice={[3:4]}
reshape.17 = s32[] reshape(slice.31)
reshape.18 = s32[] reshape(constant.6)
dynamic-update-slice = f32[8,3,96,1,64]{4,3,2,1,0} dynamic-update-slice(get-tuple-element.2, select, reshape.29, reshape.15, reshape.16, reshape.17, reshape.18)
ROOT tuple.1 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(add, dynamic-update-slice, get-tuple-element.3, get-tuple-element.4)
}
while_cond {
param.0 = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) parameter(0)
get-tuple-element = s32[] get-tuple-element(param.0), index=0
constant.2 = s32[] constant(16)
ROOT compare = pred[] compare(get-tuple-element, constant.2), direction=LT
}
ENTRY main {
constant = s32[] constant(0)
z = f32[] constant(0)
b = f32[8,3,96,1,64]{4,3,2,1,0} broadcast(z), dimensions={}
i = s32[8,2,4]{2,1,0} parameter(0)
reshape = s32[16,4]{1,0} reshape(i)
u = f32[8,2,64]{2,1,0} parameter(1)
reshape.1 = f32[16,64]{1,0} reshape(u)
tuple = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) tuple(constant, b, reshape, reshape.1)
while = (s32[], f32[8,3,96,1,64]{4,3,2,1,0}, s32[16,4]{1,0}, f32[16,64]{1,0}) while(tuple), condition=while_cond, body=while_body
ROOT get-tuple-element.5 = f32[8,3,96,1,64]{4,3,2,1,0} get-tuple-element(while), index=1
}
)";
// Constant updates: any element the loop scatters incorrectly shows up as a
// mismatch against the reference result.
Literal updates =
Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {8, 2, 64}));
updates.PopulateWithValue(1.0f);
Literal indices =
Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {8, 2, 4}));
// Generate index components that cycle within each dimension's valid range
// (the while_body above bounds-checks them against {7, 2, 95, 0, 0}).
indices
.Populate<int>([&](absl::Span<const int64> indices) -> int {
auto i = indices[2] + indices[1] * 4 + indices[0] * 2 * 4;
switch (indices[2]) {
case 0:
return i % 8;
case 1:
return i % 3;
case 2:
return i % 96;
default:
return 0;
}
})
.IgnoreError();
// Populate's Status is deliberately dropped; the lambda cannot fail.
// Zero tolerance: the backend result must match the reference exactly.
ErrorSpec no_error(0, 0);
EXPECT_TRUE(
RunAndCompare(ParseAndReturnVerifiedModule(kModuleStr).ValueOrDie(),
{&indices, &updates}, no_error));
}
} // anonymous namespace
} // namespace xla

View File

@@ -23,6 +23,37 @@ limitations under the License.
namespace xla {
namespace llvm_ir {
// Conservatively determines whether `instr` might be emitted as an in-place
// dynamic-update-slice (DUS) once buffer assignment is available.  A false
// result is definitive; a true result only means "possibly in-place".
bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr) {
// Today we can't emit a dynamic-update-slice if the DUS node is parallelized;
// the emitter will not emit correct code. It's possible to change this, but
// then ParallelTaskAssigner would have to somehow know whether a node *will*
// be emitted as an in-place DUS, and it can't, because it doesn't have a
// buffer assignment when it runs.
if (!instr->outer_dimension_partitions().empty()) {
return false;
}
// Until we know the final buffer assignment, any unfused dynamic-update-slice
// might be implementable as an in-place DUS.
if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) {
return true;
}
// A fusion may be implementable as an in-place dynamic update slice if
// - it's a loop fusion,
// - dynamic-update-slice is the root of the fusion, and
// - operand 0 of the dynamic-update-slice is a parameter to the fusion
// (ignoring any get-tuple-element operations in the way).
if (instr->IsLoopFusion()) {
const HloInstruction* fused_root = instr->fused_expression_root();
return fused_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
fused_root->operand(0)->LatestNonGteAncestor()->opcode() ==
HloOpcode::kParameter;
}
// Any other opcode can never be an in-place DUS.
return false;
}
bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
const BufferAssignment& assignment) {
CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
@@ -32,6 +63,29 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
}
// Returns true if `fusion` — which must be a kFusion instruction — can be
// emitted as an in-place dynamic-update-slice: the fusion must pass the
// structural check above, and the fusion parameter feeding operand 0 of the
// root DUS must share its buffer slice with the fusion's output, per
// `assignment`.
bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
const BufferAssignment& assignment) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
// First apply the structural (buffer-assignment-free) filter.
if (!MayBeImplementedAsInPlaceDynamicUpdateSlice(fusion)) {
return false;
}
// Walk DynamicUpdateSlice operand(0) to fused parameter and get its
// associated operand. See if it shares an allocation with this operand.
HloInstruction* fused_root = fusion->fused_expression_root();
HloInstruction* fusion_operand;
ShapeIndex index;
std::tie(fusion_operand, index) =
fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
// MayBeImplementedAsInPlaceDynamicUpdateSlice should have ensured that
// fusion_operand is a parameter.
CHECK_EQ(fusion_operand->opcode(), HloOpcode::kParameter);
auto* operand = fusion->operand(fusion_operand->parameter_number());
// In-place is only legal when the input buffer and the fusion's output
// buffer are literally the same slice.
return assignment.HasAllocationAt(operand, index) &&
assignment.HasAllocationAt(fusion, {}) &&
assignment.SharesSliceAtIndex(fusion, {}, operand, index);
}
// Shared implementation of EmitDynamicUpdateSliceInPlace and
// EmitFusedDynamicUpdateSliceInPlace.
//

View File

@@ -30,6 +30,22 @@ namespace llvm_ir {
using GeneratorForOperandIrArrays =
std::function<std::vector<llvm_ir::IrArray>()>;
// Determines whether the given instruction might be implemented as an
// in-place dynamic-update-slice after we have a buffer assignment.
//
// If this returns false, then CanUpdateDynamicSliceInPlace and
// CanEmitFusedDynamicUpdateSliceInPlace will also return false.
//
// This is useful if you want to check whether an instruction might be an
// in-place DUS during an HLO pass, at which point you don't have a buffer
// assignment.
//
// Note that simplifications to the HLO graph might change this function from
// returning false to returning true. Specifically, simplifying the contents of
// fusion nodes might cause a false->true transition. In general this isn't a
// problem by the time you're calling this function, but beware.
bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr);
// Checks if we can emit code for the given DynamicUpdateSlice node that updates
// its input in place. Returns true if the dynamic-update-slice's
// array-to-be-updated and output share the same BufferAllocation::Slice.
@@ -40,28 +56,8 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
// Checks if the given fusion node is amenable to being implemented by
// EmitFusedDynamicUpdateSliceInPlace.
inline bool CanEmitFusedDynamicUpdateSliceInPlace(
HloInstruction* fusion, const BufferAssignment& assignment) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
HloInstruction* fused_root = fusion->fused_expression_root();
if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice ||
!fusion->IsLoopFusion()) {
return false;
}
// Walk DynamicUpdateSlice operand(0) to fused parameter and get its
// associated operand. See if it shares an allocation with this operand.
HloInstruction* fusion_operand;
ShapeIndex index;
std::tie(fusion_operand, index) =
fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
if (fusion_operand->opcode() != HloOpcode::kParameter) {
return false;
}
auto* operand = fusion->operand(fusion_operand->parameter_number());
return assignment.HasAllocationAt(operand, index) &&
assignment.HasAllocationAt(fusion, {}) &&
assignment.SharesSliceAtIndex(fusion, {}, operand, index);
}
bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
const BufferAssignment& assignment);
// Emits IR for running the given dynamic-update-slice op in-place -- that is,
// where the input and output buffers share the same slice, so we can simply