1.12.0-rc2 cherry-pick request: Various XLA scatter improvements. (#23235)

* [XLA] Update Tf2Xla bridge to use Scatter HLO.

PiperOrigin-RevId: 215687800

* [XLA:GPU] Add an implementation of scatter for GPU

This simple has a kernel that runs on every element of the updates tensor,
figure out the right indices to perform the update, and applies it with an
atomic operation.

Currently we emit a CAS for plain (i.e. non-add) updates, which is inefficient.
Also TuplePointsToAnalysis doesn't know that it should alias the operand and
output buffers of a scatter, which would avoid a copy.

PiperOrigin-RevId: 216412467

* [XLA] Allow scatter to share the operand buffer with the output

This avoids a copy.

PiperOrigin-RevId: 216437329

* [XLA:GPU] Elide the SequentialThunk when emitting scatter with no copy

We have a 1-element thunk sequence if we're not copying. That's still two
thunks and hlo profiling gets confused if it sees two thunks for the same
instruction and one of them claims to be the whole instruction.

PiperOrigin-RevId: 216448063

* [XLA:GPU] Allow input fusion into scatter

We fuse everything into the scatter now, and emit two kernels. The first kernel
fills the output buffer with the computation fused into the scatter operand.
The second kernel is a regular scatter, which also contains the fused
operations from the updates and scatter_indices inputs.

PiperOrigin-RevId: 216624225

* [XLA:GPU] Adding a test case for Scatter where GPU implementation fails.

PiperOrigin-RevId: 216798034

* [XLA:GPU] Fix scatter oob check computation

This was comparing the index after adding it to the window, and then comparing
against the window dimension. This means that the bounds check was only correct
for the first element of a window. Instead compare the scatter index, which is
the same for all elements of a window.

PiperOrigin-RevId: 216921512

* [XLA:GPU] Elide tuple roots of the entry computation

The tuple buffer is never read, so stop emitting code to fill it. A typical
root tuple consists of a H2D memcpy and a host callback, both of which are
somewhat slow.

This helps tiny models and inference benchmarks, where the host/device syncs
can be a significant part of the runtime of the entire computation.

PiperOrigin-RevId: 216968475
This commit is contained in:
Todd Wang 2018-10-24 17:46:03 -07:00 committed by GitHub
parent da1b48ddd0
commit e72c9ebe78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 609 additions and 115 deletions

View File

@ -38,12 +38,10 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
combiner,
xla::XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates));
TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
absl::Span<const int64> indices_dims =
xla::AsInt64Slice(indices_shape.dimensions());
absl::Span<const int64> buffer_dims =
xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
// the indices to update. Otherwise the indices are all scalars.
@ -81,104 +79,129 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
}
}
// Shape of the non-indexed dimensions of the buffer.
std::vector<int64> buffer_shape_post_axes(
buffer_dims.begin() + num_index_dims, buffer_dims.end());
// Example of a 1-D scatter that updates two [3,1] tensors in a tensor of
// shape [3,3]:
// NOTE: ***This case will not be generated by any of the tf.scatter ops.***
//
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// updates = s32[3,2] parameter(2)
// scatter = s32[3,3] scatter(operand, indices, updates),
// to_apply=update_computation,
// update_window_dims={0},
// inserted_window_dims={1},
// scatter_dims_to_operand_dims={1},
// index_vector_dim=1
//
//
// Example of a 1-D scatter that updates two [1,3] tensors in a tensor of
// shape [3,3]:
//
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// updates = s32[2,3] parameter(2)
// scatter = s32[3,3] scatter(operand, indices, updates),
// to_apply=update_computation,
// update_window_dims={1},
// inserted_window_dims={0},
// scatter_dims_to_operand_dims={0},
// index_vector_dim=1
//
//
// Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of
// shape [3,3,2]
//
// operand = s32[3,3,2] parameter(0)
// indices = s32[2,2] parameter(1)
// updates = s32[2,2] parameter(2)
// scatter = s32[3,3,2] scatter(operand, indices, updates),
// to_apply=update_computation,
// update_window_dims={1},
// inserted_window_dims={0,1},
// scatter_dims_to_operand_dims={0,1},
// index_vector_dim=1
//
//
// Example of a scatter updating slices of shape [] in a tensor of shape [1,1]
//
// operand = s32[1,1] parameter(0)
// indices = s32[1] parameter(1)
// updates = s32[1] parameter(2)
// scatter = s32[1,1] scatter(operand, indices, updates),
// to_apply=update_computation,
// update_window_dims={},
// inserted_window_dims={0,1},
// scatter_dims_to_operand_dims={0},
// index_vector_dim=1
// Note that updates operand would be broadcasted into [1] in this case.
//
// Flatten the major dimensions of indices and updates into a single dimension
// for ease of iteration.
std::vector<int64> flat_indices_shape({num_indices});
if (indices_are_vectors) {
flat_indices_shape.push_back(num_index_dims);
xla::ScatterDimensionNumbers dim_numbers;
dim_numbers.set_index_vector_dim(indices_are_vectors
? indices_shape.dimensions_size() - 1
: indices_shape.dimensions_size());
int64 updates_rank = xla::ShapeUtil::Rank(updates_shape);
int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape);
int64 num_window_dims_in_updates = buffer_rank - num_index_dims;
// If the rank of `updates` is 0 and does not match the expected rank of
// updates, broadcast `updates` to the expected shape of updates.
auto new_updates = updates;
std::vector<int64> expected_updates_dims(indices_dims.begin(),
indices_dims.end());
for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) {
expected_updates_dims.push_back(buffer_shape.dimensions(dim));
}
int64 expected_updates_rank = expected_updates_dims.size();
if (updates_rank == 0 && expected_updates_rank != 0) {
new_updates = xla::Broadcast(updates, expected_updates_dims);
TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
updates_rank = xla::ShapeUtil::Rank(updates_shape);
}
std::vector<int64> flat_updates_shape({num_indices});
flat_updates_shape.insert(flat_updates_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
// Construct the initial values of the loop-carried Tensors.
auto flat_indices = xla::Reshape(indices, flat_indices_shape);
auto flat_updates = xla::Reshape(updates, flat_updates_shape);
auto init = {flat_indices, flat_updates, buffer};
// Constructs the loop body. The implementation of scatter is essentially:
// for i in range(num_indices):
// index = dynamic-slice(indices, i)
// update = dynamic-slice(updates, i)
// buffer = dynamic-update-slice(buffer, update, index)
auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* body_builder) {
auto indices = loop_vars[0];
auto updates = loop_vars[1];
auto buffer = loop_vars[2];
auto zero_index = xla::ConstantLiteral(
body_builder, xla::LiteralUtil::Zero(indices_shape.element_type()));
// Slice the i-th index from the indices array.
xla::XlaOp index;
auto indices_offset = xla::Reshape(i, {1});
if (indices_are_vectors) {
indices_offset = xla::Pad(indices_offset, zero_index,
xla::MakeEdgePaddingConfig({{0, 1}}));
index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims});
index = xla::Collapse(index, {0, 1});
} else {
index = xla::DynamicSlice(indices, indices_offset, {1});
if (updates_rank > 0) {
for (int64 i = (updates_rank - num_window_dims_in_updates);
i < updates_rank; ++i) {
dim_numbers.add_update_window_dims(i);
}
}
// Discard updates with negative indices, since some users expect this.
auto index_in_range = xla::ReduceAll(
xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true),
xla::CreateScalarAndComputation(xla::PRED, body_builder));
for (int64 i = 0; i < num_index_dims; ++i) {
dim_numbers.add_inserted_window_dims(i);
dim_numbers.add_scatter_dims_to_operand_dims(i);
}
// Make the index in bounds to prevent implementation defined behavior.
index = xla::Max(index, zero_index);
index = xla::Pad(
index, zero_index,
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
// Slice the i-th index from the updates array.
auto updates_offset = xla::Reshape(i, {1});
updates_offset = xla::Pad(
updates_offset, zero_index,
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
std::vector<int64> flat_updates_slice_shape({1});
flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
auto update =
xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape);
// Unflatten the major (iteration) dimensions of the slice to their
// original shape.
std::vector<int64> updates_slice_shape(num_index_dims, 1);
updates_slice_shape.insert(updates_slice_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
update = xla::Reshape(update, updates_slice_shape);
// Apply the update to the buffer. If there is a combiner, use it to merge
// the current values with the update.
auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape);
// Build the combiner computation.
xla::XlaComputation combiner_computation;
{
xla::XlaBuilder cb("scatter-combiner");
auto xla_scalar_shape =
xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {});
auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1");
if (combiner) {
update = combiner(current_value, update, body_builder);
combiner(p0, p1, &cb);
}
// Use the current value instead of the update if the index is out of
// bounds.
update = xla::Select(index_in_range, update, current_value);
// Apply the update.
buffer = xla::DynamicUpdateSlice(buffer, update, index);
combiner_computation = cb.Build().ConsumeValueOrDie();
}
return std::vector<xla::XlaOp>{indices, updates, buffer};
};
VLOG(3) << "Scatter op:";
VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape);
VLOG(3) << " Indices: " << xla::ShapeUtil::HumanString(indices_shape);
VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape);
VLOG(3) << " Scatter Dimension Numbers: ";
VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim();
VLOG(3) << " update_window_dims: ["
<< absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]";
VLOG(3) << " inserted_window_dims: ["
<< absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]";
VLOG(3) << " scatter_dims_to_operand_dims: ["
<< absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",")
<< "]";
TF_ASSIGN_OR_RETURN(auto outputs,
XlaForEachIndex(num_indices, indices_shape.element_type(),
body_fn, init, "scatter", builder));
return outputs[2];
return xla::Scatter(buffer, indices, new_updates, combiner_computation,
dim_numbers);
}
} // namespace tensorflow

View File

@ -34,7 +34,11 @@ namespace tensorflow {
// Otherwise, `indices_are_vectors`, then indices are multidimensional and the
// minor dimension of `indices` represents a vector of indices.
//
// If any indices are negative, the corresponding update is discarded.
// If `updates` is a scalar, then it will be broadcasted into the expected shape
// of updates.
//
// If any part of the update region is out-of-bounds, the corresponding update
// is discarded.
//
// If a `combiner` is provided, updates are combined with the existing values in
// the buffer using the combiner function. Otherwise, the updates replace the

View File

@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
case HloOpcode::kWhile:
// TODO(b/32495713): We aren't checking the condition and body
// computations themselves.
case HloOpcode::kScatter:
// TODO(b/32495713): We aren't checking the embedded computation in
// Scatter.
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kParameter:

View File

@ -704,7 +704,6 @@ cc_library(
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_constant_sinking",

View File

@ -47,6 +47,7 @@ bool IsFusible(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kReduceWindow ||
hlo.opcode() == HloOpcode::kReshape ||
hlo.opcode() == HloOpcode::kScatter ||
hlo.opcode() == HloOpcode::kSlice ||
hlo.opcode() == HloOpcode::kTranspose;
}
@ -223,6 +224,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
// Scatter is only supported at the root of a kInput fusion.
if (producer->opcode() == HloOpcode::kScatter) {
return false;
}
// Do not fuse into reduce input fusions if the resulting kernel would suffer
// from poor data locality (due to unfriendly input layouts).
if (IsInputFusibleReduction(*consumer) &&
@ -285,7 +291,8 @@ bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
const HloInstruction* producer, const HloInstruction* consumer) {
if (IsReductionToVector(*consumer)) {
if (IsReductionToVector(*consumer) ||
consumer->opcode() == HloOpcode::kScatter) {
return HloInstruction::FusionKind::kInput;
}
if (producer->opcode() == HloOpcode::kDot ||

View File

@ -709,5 +709,44 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
}
}
TEST_F(InstructionFusionTest, FuseIntoScatter) {
auto module = ParseHloString(R"(
HloModule test_module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY FuseIntoScatter {
p0 = s32[3,3] parameter(0)
operand = s32[3,3] add(p0, p0)
p1 = s32[2] parameter(1)
indices = s32[2] add(p1, p1)
p2 = s32[2,3] parameter(2)
updates = s32[2,3] add(p2, p2)
scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
ROOT add = s32[3,3] add(scatter, scatter)
})")
.ValueOrDie();
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Add(op::Fusion(), op::Fusion()));
EXPECT_EQ(root->operand(0)->fusion_kind(),
HloInstruction::FusionKind::kInput);
EXPECT_THAT(root->operand(0)->fused_expression_root(),
op::Scatter(op::Add(), op::Add(), op::Add()));
}
} // namespace gpu
} // namespace xla

View File

@ -493,13 +493,68 @@ Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
HloInstruction* root = fusion->fused_expression_root();
// HandleFusion specializes reduction from a multi-dimensional array to a 1D
// array. The specialized version requires a initializer thunk that
// initializes the output array to the initial value of the reduce.
if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
switch (root->opcode()) {
case HloOpcode::kScatter: {
std::vector<std::unique_ptr<Thunk>> thunks;
// The initialization from 'operand' is using different loop bounds, so
// emit it in a separate kernel. Treat it like a loop fusion, writing to
// the output buffer.
{
int unroll_factor = ComputeMaxUnrollFactor(fusion);
thunks.push_back(BuildKernelThunk(
fusion, /*implements_whole_instruction=*/false, unroll_factor));
std::vector<IrArray> operand_parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
operand_parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
GpuElementalIrEmitter operand_elemental_emitter(
hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
GetNestedComputer());
FusedIrEmitter operand_fused_emitter(operand_parameter_arrays,
&operand_elemental_emitter);
TF_RETURN_IF_ERROR(
root->mutable_operand(0)->Accept(&operand_fused_emitter));
TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
*fusion, operand_fused_emitter.GetGenerator(root->operand(0)),
static_cast<KernelThunk*>(thunks.back().get())));
}
// Now build the actual scatter, reading and writing to the freshly
// filled output buffer.
{
thunks.push_back(
BuildKernelThunk(fusion,
/*implements_whole_instruction=*/false));
// Spin up a new fused emitter for the scatter kernel and emit it.
std::vector<IrArray> scatter_parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
scatter_parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
GpuElementalIrEmitter scatter_elemental_emitter(
hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
GetNestedComputer());
FusedIrEmitter scatter_fused_emitter(scatter_parameter_arrays,
&scatter_elemental_emitter);
TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
TF_RETURN_IF_ERROR(EmitScatter(
thunks.back().get(), root,
/*scatter_indices_gen=*/
scatter_fused_emitter.GetGenerator(root->operand(1)),
/*updates_gen=*/
scatter_fused_emitter.GetGenerator(root->operand(2))));
}
thunk_sequence_->emplace_back(
absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
return Status::OK();
}
case HloOpcode::kTuple:
case HloOpcode::kReduce: {
// HandleFusion specializes reduction from a multi-dimensional array to
// a 1D array. The specialized version requires a initializer thunk that
// initializes the output array to the initial value of the reduce.
if (root->opcode() == HloOpcode::kReduce &&
ShapeUtil::IsTuple(root->shape())) {
// TODO(b/112040122): Support variadic reduce.
@ -1672,6 +1727,14 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
}
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
// For the root node of the entry computation we can elide writing the tuple
// buffer. We can always figure out the contents of the tuples from buffer
// assignment because we insert copies to ensure non-ambiguous output buffers.
// GpuExecutable never reads the tuple buffer.
if (tuple ==
tuple->parent()->parent()->entry_computation()->root_instruction()) {
return Status::OK();
}
bool all_tuple_elements_have_buffer =
absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
return ir_emitter_context_->buffer_assignment()
@ -1958,6 +2021,178 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
return Status::OK();
}
Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
const HloInstruction* operand = scatter->operand(0);
const HloInstruction* scatter_indices = scatter->operand(1);
const HloInstruction* updates = scatter->operand(2);
std::vector<std::unique_ptr<Thunk>> thunks;
// Copy the operand into the output if it's not the same buffer already.
auto operand_buffer = GetAllocationSlice(*operand);
auto destination_buffer = GetAllocationSlice(*scatter);
if (operand_buffer != destination_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter));
}
thunks.push_back(
BuildKernelThunk(scatter,
/*implements_whole_instruction=*/thunks.empty()));
TF_RETURN_IF_ERROR(
EmitScatter(thunks.back().get(), scatter,
/*scatter_indices_gen=*/
[=](const IrArray::Index& index) {
return GetIrArray(*scatter_indices, *scatter)
.EmitReadArrayElement(index, &b_, "scatter_index");
},
/*updates_gen=*/
[=](const IrArray::Index& index) {
return GetIrArray(*updates, *scatter)
.EmitReadArrayElement(index, &b_, "update");
}));
// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
thunk_sequence_->push_back(std::move(thunks[0]));
} else {
thunk_sequence_->emplace_back(
absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
}
return Status::OK();
}
Status IrEmitterUnnested::EmitScatter(
Thunk* thunk, HloInstruction* scatter,
const llvm_ir::ElementGenerator& scatter_indices_gen,
const llvm_ir::ElementGenerator& updates_gen) {
const HloInstruction* operand = scatter->operand(0);
const HloInstruction* scatter_indices = scatter->operand(1);
const HloInstruction* updates = scatter->operand(2);
const ScatterDimensionNumbers& dim_numbers =
scatter->scatter_dimension_numbers();
CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
std::vector<llvm::Value*> raw_window_multidim;
std::vector<llvm::Value*> input_scatter_multidim;
std::vector<int64> raw_window_bounds;
// Partition the index into window indices and scatter indices.
for (int64 i = 0, e = index.size(); i != e; ++i) {
// For window indices also remember the window size, this comes in handy
// later.
if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
raw_window_multidim.push_back(index[i]);
raw_window_bounds.push_back(updates->shape().dimensions(i));
} else {
input_scatter_multidim.push_back(index[i]);
}
}
DCHECK_EQ(raw_window_multidim.size(),
dim_numbers.update_window_dims_size());
// Apply inserted_window_dims to the window dimensions.
int64 raw_window_multidim_idx = 0;
std::vector<llvm::Value*> input_window_multidim;
std::vector<int64> input_window_bounds;
for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) {
if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
input_window_bounds.push_back(1); // Trivial dimension.
input_window_multidim.push_back(index.GetConstantWithIndexType(0));
} else {
input_window_bounds.push_back(
raw_window_bounds[raw_window_multidim_idx]);
input_window_multidim.push_back(
raw_window_multidim[raw_window_multidim_idx]);
++raw_window_multidim_idx;
}
}
DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape()));
// Insert a 1 dimension at the end if index_vector_dim requests one.
Shape scatter_indices_shape = scatter_indices->shape();
if (dim_numbers.index_vector_dim() ==
ShapeUtil::Rank(scatter_indices_shape)) {
scatter_indices_shape.add_dimensions(1);
scatter_indices_shape.mutable_layout()->add_minor_to_major(
dim_numbers.index_vector_dim());
}
// Now load the indices corresponding to the current window from
// scatter_indices.
llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim,
index.GetType());
raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr);
llvm::Value* is_in_bounds = b_.getTrue();
for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
i != e; ++i) {
// Our index is stored along index_vector_dim, insert that into the lookup
// index into scatter_indices.
raw_scatter_index_index[dim_numbers.index_vector_dim()] =
raw_scatter_index_index.GetConstantWithIndexType(i);
int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
TF_ASSIGN_OR_RETURN(
llvm::Value* const loaded_scatter_index,
scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
scatter_indices_shape, scatter_indices->shape(), &b_)));
// And add the index to our window index. This yields the output index.
llvm::Value* casted_scatter_index =
IntCast(loaded_scatter_index, index.GetType(),
/*isSigned=*/true);
llvm::Value* dim_offset =
Add(input_window_multidim[operand_dim], casted_scatter_index);
input_window_multidim[operand_dim] = dim_offset;
// Also do the bounds check now.
int64 max_index = operand->shape().dimensions(operand_dim) -
input_window_bounds[operand_dim] + 1;
// is_in_bounds = index >= 0 && index < dim_size-window_size+1
// --> index u< dim_size-window_size+1
is_in_bounds =
And(is_in_bounds, ICmpULT(casted_scatter_index,
index.GetConstantWithIndexType(max_index)));
}
llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
// All done, now just read from the calculated input from the window, and do
// an atomic store to the calculated location in the output.
llvm_ir::IrArray::Index input_window_index(input_window_multidim,
index.GetType());
HloInstruction* output_hlo =
scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter;
llvm::Value* output_address =
GetIrArray(*output_hlo, *output_hlo)
.EmitArrayElementAddress(input_window_index, &b_);
llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType(
updates->shape().element_type(), module_));
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
Store(input_ir_value, input_address);
return EmitAtomicOperationForNestedComputation(
*scatter->to_apply(), output_address, input_address);
};
// Launch a kernel that reads every element in the updates tensor. We could
// also do one kernel per window instead if bounds checks turn out to be a
// bottleneck.
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
updates->shape(), ir_emitter_context_->device_description());
UpdateLaunchDimensions(launch_dimensions, thunk,
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
launch_dimensions, &b_)
.EmitLoop(IrName(scatter),
GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
&b_));
}
Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
thunk_sequence_->push_back(
BuildKernelThunk(select, /*implements_whole_instruction=*/true));

View File

@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleInfeed(HloInstruction* xla_infeed) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleRng(HloInstruction* random) override;
Status HandleScatter(HloInstruction* scatter) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleSort(HloInstruction* sort) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
@ -184,6 +185,14 @@ class IrEmitterUnnested : public IrEmitter {
absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
// the process. `scatter` may be fused, scatter indices are taken from
// `scatter_indices_gen`, updates from`updates_gen`. The output buffer is
// expected to have the operand values in it already.
Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
const llvm_ir::ElementGenerator& scatter_indices_gen,
const llvm_ir::ElementGenerator& updates_gen);
// Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
// for the hlo instruction.
bool CheckAndEmitHloWithTile021(HloInstruction* hlo);

View File

@ -75,7 +75,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@ -176,8 +175,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// elimination has to come after that pass.
pipeline.AddPass<ZeroSizedHloElimination>();
pipeline.AddPass<ScatterExpander>();
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; });

View File

@ -1072,6 +1072,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
user->opcode() == HloOpcode::kScatter ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.

View File

@ -2283,6 +2283,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
}
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
const char* hlo_text = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));
computation_ = module_->entry_computation();
RunAnalysis();
HloInstruction* operand_param = computation_->parameter_instruction(0);
HloInstruction* indices_param = computation_->parameter_instruction(1);
HloInstruction* updates_param = computation_->parameter_instruction(2);
HloInstruction* scatter = computation_->root_instruction();
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
operand_param, {}, scatter, {}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
indices_param, {}, scatter, {}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
updates_param, {}, scatter, {}));
}
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto builder = HloComputation::Builder(TestName());

View File

@ -216,6 +216,7 @@ HLO_MATCHER(Remainder);
HLO_MATCHER(Reshape);
HLO_MATCHER(Reverse);
HLO_MATCHER(Rng);
HLO_MATCHER(Scatter);
HLO_MATCHER(Select);
HLO_MATCHER(SelectAndScatter);
HLO_MATCHER(Send);

View File

@ -146,7 +146,8 @@ void HloModule::ReplaceComputations(
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow: {
case HloOpcode::kReduceWindow:
case HloOpcode::kScatter: {
HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
replacements, instruction->to_apply(), nullptr);
if (new_arg != nullptr) {

View File

@ -71,17 +71,31 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
// profitability model for inlining is defined.
if (hlo_query::AllOperandsAreParameters(root)) {
if (root.opcode() == HloOpcode::kFusion ||
root.opcode() == HloOpcode::kParameter ||
root.opcode() == HloOpcode::kTrace) {
// Cloning not supported for these instructions.
return Status::OK();
}
VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
<< root.ToShortString();
// If the input is a constant then the shape of the constant could be
// different than the map shape. Hence, a broadcast is needed, else the
// cloned operand with new shape and operands work.
if (root.opcode() != HloOpcode::kConstant) {
if (root.opcode() == HloOpcode::kParameter) {
// If the root is a parameter, then use the corresponding operand as the
// result of the computation.
TF_RETURN_IF_ERROR(
map->ReplaceAllUsesWith(map->operands()[root.parameter_number()]));
TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map));
} else if (root.opcode() == HloOpcode::kConstant) {
// If the input is a constant then the shape of the constant could be
// different than the map shape. Hence, a broadcast is needed, else the
// cloned operand with new shape and operands work.
//
// The constant is in an embedded computation and needs to be recreated
// as part of the computation that the broadcast is inserted into.
HloInstruction* constant = computation_->AddInstruction(root.Clone());
HloInstruction* placed_instruction = computation_->AddInstruction(
HloInstruction::CreateBroadcast(map->shape(), constant, {}));
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(map, placed_instruction));
} else {
std::vector<HloInstruction*> params;
for (int64 o = 0; o < root.operands().size(); o++) {
params.push_back(map->operands()[root.operand(o)->parameter_number()]);
@ -90,14 +104,6 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
root.CloneWithNewOperands(map->shape(), params));
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(map, placed_instruction));
} else {
// The constant is in an embedded computation and needs to be recreated
// as part of the computation that the broadcast is inserted into.
HloInstruction* constant = computation_->AddInstruction(root.Clone());
HloInstruction* placed_instruction = computation_->AddInstruction(
HloInstruction::CreateBroadcast(map->shape(), constant, {}));
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(map, placed_instruction));
}
changed_ = true;
return Status::OK();

View File

@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapParameter) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
auto param_builder = HloComputation::Builder(TestName());
param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
auto param_f32 = param_builder.Build();
auto builder = HloComputation::Builder("MapParamFunction");
auto lhs = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
auto rhs = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
builder.AddInstruction(
HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));
auto computation = builder.Build();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEmbeddedComputation(std::move(param_f32));
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
// Verify execution on CPU.
auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(4);
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
} // namespace
} // namespace xla

View File

@ -1862,6 +1862,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kRemainder:
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kScatter:
case HloOpcode::kSelect:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kShiftLeft:
@ -1899,7 +1900,6 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kReduce:
case HloOpcode::kReshape:
case HloOpcode::kRng:
case HloOpcode::kScatter:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kAfterAll:

View File

@ -771,6 +771,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
}
}
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
user->opcode() == HloOpcode::kScatter ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.

View File

@ -1010,6 +1010,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
}
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
const char* hlo_text = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));
computation_ = module_->entry_computation();
RunAnalysis();
HloInstruction* operand_param = computation_->parameter_instruction(0);
HloInstruction* indices_param = computation_->parameter_instruction(1);
HloInstruction* updates_param = computation_->parameter_instruction(2);
HloInstruction* scatter = computation_->root_instruction();
EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(
operand_param, {}, scatter, {}));
EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(
indices_param, {}, scatter, {}));
EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(
updates_param, {}, scatter, {}));
}
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto builder = HloComputation::Builder(TestName());

View File

@ -69,6 +69,37 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterV1_WithFusedAdds) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
p0 = s32[3,3] parameter(0)
operand = s32[3,3] add(p0, p0)
p1 = s32[2] parameter(1)
indices = s32[2] add(p1, p1)
p2 = s32[2,3] parameter(2)
updates = s32[2,3] add(p2, p2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 1});
Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
const char* hlo_text = R"(
HloModule TensorFlowScatterV2
@ -98,6 +129,37 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, SimpleR4) {
const char* hlo_text = R"(
HloModule SimpleR4
add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}
ENTRY main {
operand = f32[1,2,2,1] parameter(0)
indices = s32[1,3] parameter(1)
updates = f32[1,2,2,1] parameter(2)
ROOT scatter = f32[1,2,2,1] scatter(operand, indices, updates),
to_apply=add_f32,
update_window_dims={1,2,3},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0, 2, 1},
index_vector_dim=1
}
)";
Literal operand =
LiteralUtil::CreateR4<float>({{{{0.f}, {0.f}}, {{0.f}, {0.f}}}});
Literal updates =
LiteralUtil::CreateR4<float>({{{{0.12}, {0.28}}, {{0.018}, {0.42}}}});
Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0, 0}});
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
const string hlo_text = R"(
HloModule TensorFlowScatter_Add