Change scatter emission to go through MLIR

- For a non-fused scatter, convert it to an LHLO ScatterOp and use that to drive code emission.
- For a fused scatter, extract the relevant pieces and use those to drive scatter code emission.
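
At a high level, both paths now converge on a single MLIR-driven emitter. A condensed sketch of the non-fused path (taken from the code in this change; signatures simplified and error handling elided):

    // HandleScatter first converts the HLO scatter to an lmhlo::ScatterOp via
    // the scratch LHLO emitter, then drives all code emission from that op.
    Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
      MlirEmitterInput input;
      TF_ASSIGN_OR_RETURN(input.op, lhlo_scratch_emitter_.EmitScatterOp(scatter));
      input.thunk_info = GetThunkInfo(scatter);
      return EmitScatterFromMlir(input);
    }

The fused path instead fills a ScatterDescriptor (see ir_emitter_unnested.h below) from the fusion root and calls the same descriptor-based EmitScatter overload.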

PiperOrigin-RevId: 338266178
Change-Id: I659d1f4aa7e13048a37a1e3f85e869bf25e84a0f
Authored by Rahul Joshi on 2020-10-21 08:19:02 -07:00; committed by TensorFlower Gardener
parent fec830a3c8
commit 0d03786e46
6 changed files with 259 additions and 99 deletions


@@ -134,8 +134,8 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
return Status::OK();
}
// This pass take a MLIR HLO module, convert it to XLA to perform the HLO
// optimization pipeline for the required platform, and then convert back to
// This pass takes an MLIR HLO module, converts it to XLA to perform the HLO
// optimization pipeline for the required platform, and then converts it back to
// MLIR LHLO.
class XlaHloToLhloPass
: public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
@@ -400,6 +400,49 @@ Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) {
return EmitFusionOp(instr).status();
}
StatusOr<mhlo::ScatterDimensionNumbers>
LhloDialectEmitter::GetScatterDimensionNumbers(HloInstruction* instr) {
auto* scatter_instr = ::xla::Cast<::xla::HloScatterInstruction>(instr);
const ::xla::ScatterDimensionNumbers& xla_scatter_dim =
scatter_instr->scatter_dimension_numbers();
auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get(
getI64DenseElementsAttr(xla_scatter_dim.update_window_dims()),
getI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()),
getI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()),
builder_.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()),
module_.getContext());
return scatter_dimension_numbers;
}
StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto scatter,
CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));
// copy attributes
auto* scatter_instr = ::xla::Cast<::xla::HloScatterInstruction>(instr);
TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers,
GetScatterDimensionNumbers(instr));
scatter.scatter_dimension_numbersAttr(scatter_dimension_numbers);
scatter.indices_are_sortedAttr(
builder_.getBoolAttr(scatter_instr->indices_are_sorted()));
scatter.unique_indicesAttr(
builder_.getBoolAttr(scatter_instr->unique_indices()));
// import update computation as region
TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
*scatter_instr->called_computations()[0], &scatter.update_computation(),
&builder_));
return scatter;
}
Status LhloDialectEmitter::HandleScatter(HloInstruction* instr) {
return EmitScatterOp(instr).status();
}
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& shape_index) {
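
Returning to GetScatterDimensionNumbers above: for concreteness, here is the attribute it builds for a typical scatter configuration (the values and the `ctx`/`b` names are illustrative; the call mirrors the one in the function body):

    mlir::Builder b(ctx);
    auto dims = mlir::mhlo::ScatterDimensionNumbers::get(
        b.getI64TensorAttr({1}),  // update_window_dims
        b.getI64TensorAttr({0}),  // inserted_window_dims
        b.getI64TensorAttr({0}),  // scatter_dims_to_operand_dims
        b.getI64IntegerAttr(1),   // index_vector_dim
        ctx);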


@@ -16,11 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
namespace mlir {
@@ -44,11 +46,20 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr);
::xla::StatusOr<mhlo::ScatterDimensionNumbers> GetScatterDimensionNumbers(
::xla::HloInstruction* instr);
private:
template <typename OpType>
::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr);
template <typename T>
DenseIntElementsAttr getI64DenseElementsAttr(const T& container) {
return builder_.getI64TensorAttr(
{container.data(), static_cast<size_t>(container.size())});
}
tensorflow::Status DefaultAction(::xla::HloInstruction* instr) final;
// Computation parameters don't need any specific handling when they are
@@ -59,6 +70,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
tensorflow::Status HandleSort(::xla::HloInstruction* instr) final;
tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final;
tensorflow::Status HandleScatter(::xla::HloInstruction* instr) final;
// Helper function that recursively visits the tuple structure in
// `current_shape`, and reconstruct a matching lmhlo::TupleOp.
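
A small illustration of the getI64DenseElementsAttr helper above: it maps any contiguous int64 container exposing data()/size() (e.g. a proto repeated field) to a 1-D i64 tensor attribute. The values here are hypothetical:

    // Inside LhloDialectEmitter, with builder_ in scope:
    std::vector<int64_t> update_window_dims = {1, 2};
    mlir::DenseIntElementsAttr attr = getI64DenseElementsAttr(update_window_dims);
    // attr == dense<[1, 2]> : tensor<2xi64>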


@@ -38,6 +38,7 @@ limitations under the License.
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
@@ -203,7 +204,7 @@ StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
}
StatusOr<std::vector<MlirBufferSlice>> GetMlirBufferSlices(
mlir::Operation* op, mlir::OperandRange operands,
mlir::Operation* op, mlir::ValueRange operands,
absl::Span<const BufferAllocation> allocations) {
const auto buffer_is_written = [op](mlir::Value operand) {
llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
@@ -227,6 +228,14 @@ StatusOr<std::vector<MlirBufferSlice>> GetMlirBufferSlices(
return slices;
}
bool BinarySearchDenseElementsAttr(::mlir::DenseIntElementsAttr elements,
int64 v) {
::mlir::APInt value(sizeof(int64) * 8, v, /*isSigned=*/true);
return std::binary_search(
elements.begin(), elements.end(), value,
[](const ::mlir::APInt& x, const ::mlir::APInt& y) { return x.slt(y); });
}
} // namespace
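
A usage sketch for BinarySearchDenseElementsAttr: it tests membership in a sorted i64 elements attribute, mirroring what absl::c_binary_search did on the proto field in the pre-MLIR code. Names and values are illustrative:

    mlir::Builder b(context);
    auto update_window_dims = b.getI64TensorAttr({0, 2, 4});
    bool hit = BinarySearchDenseElementsAttr(update_window_dims, 2);   // true
    bool miss = BinarySearchDenseElementsAttr(update_window_dims, 3);  // false
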
IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
@@ -444,7 +453,7 @@ llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op,
}
// Check the size of the internal result tensors
if (auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op)) {
if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
auto result = fusion.region().walk([&](mlir::Operation* op) {
for (mlir::Value result : op->getResults()) {
if (!hlo_shape_in_range(result)) {
@@ -800,12 +809,30 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
GetGeneratorForOperandIrArrays(fusion),
&scatter_elemental_emitter);
TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
TF_RETURN_IF_ERROR(EmitScatter(
thunks.back().get(), root,
/*scatter_indices_gen=*/
scatter_fused_emitter.GetGenerator(root->operand(1)),
/*updates_gen=*/
scatter_fused_emitter.GetGenerator(root->operand(2))));
CHECK_EQ(root->parent()->FusionInstruction(), fusion);
TF_ASSIGN_OR_RETURN(
const auto dim_numbers,
lhlo_scratch_emitter_.GetScatterDimensionNumbers(root));
ScatterDescriptor desc;
desc.name = IrName(root);
desc.operand_shape = root->operand(0)->shape();
desc.scatter_indices_shape = root->operand(1)->shape();
desc.updates_shape = root->operand(2)->shape();
desc.dim_numbers = dim_numbers;
desc.unique_indices = root->unique_indices();
desc.update_computation = root->called_computations()[0];
desc.output = GetIrArray(*fusion, *fusion);
desc.scatter_indices_gen =
scatter_fused_emitter.GetGenerator(root->operand(1));
desc.updates_gen =
scatter_fused_emitter.GetGenerator(root->operand(2));
desc.get_index_type = [&](int64 launch_size) {
return GetIndexTypeForKernel(root, launch_size, &b_);
};
TF_RETURN_IF_ERROR(EmitScatter(desc, thunks.back().get()));
}
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(fusion), std::move(thunks)));
@@ -1233,61 +1260,118 @@ Status IrEmitterUnnested::HandleRngGetAndUpdateState(
}
Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
const HloInstruction* operand = scatter->operand(0);
const HloInstruction* scatter_indices = scatter->operand(1);
const HloInstruction* updates = scatter->operand(2);
MlirEmitterInput result;
TF_ASSIGN_OR_RETURN(auto scatter_op,
lhlo_scratch_emitter_.EmitScatterOp(scatter));
result.op = scatter_op;
result.thunk_info = GetThunkInfo(scatter);
return EmitScatterFromMlir(result);
}
Status IrEmitterUnnested::EmitScatterFromMlir(MlirEmitterInput mlir_input) {
std::vector<std::unique_ptr<Thunk>> thunks;
absl::Span<const BufferAllocation> allocations(
ir_emitter_context_->buffer_assignment().Allocations());
::mlir::lmhlo::ScatterOp scatter_op =
::mlir::cast<::mlir::lmhlo::ScatterOp>(mlir_input.op);
TF_ASSIGN_OR_RETURN(
auto operand_buffer,
GetAllocationSliceForMlir(scatter_op.operand(), allocations));
TF_ASSIGN_OR_RETURN(
auto output_buffer,
GetAllocationSliceForMlir(scatter_op.output(), allocations));
// Copy the operand into the output if it's not the same buffer already.
auto operand_buffer = GetAllocationSlice(*operand);
auto destination_buffer = GetAllocationSlice(*scatter);
if (operand_buffer != destination_buffer) {
if (operand_buffer != output_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape())));
/*destination_buffer=*/output_buffer,
/*mem_size=*/
ShapeUtil::ByteSizeOf(TypeToShape(scatter_op.output().getType()))));
}
thunks.push_back(
BuildKernelThunk(scatter,
/*implements_whole_instruction=*/thunks.empty()));
// Create MLIR buffer slice info for all operands except the first one
// (`operand`). The code generated for scatter below assumes that the input
// operand is already copied into the output, so does not use it in codegen.
TF_ASSIGN_OR_RETURN(
std::vector<MlirBufferSlice> operand_slices,
GetMlirBufferSlices(scatter_op, scatter_op.getOperands().drop_front(),
allocations));
std::string name = mlir::GetNameFromLoc(scatter_op.getLoc());
std::vector<llvm_ir::IrArray> ir_arrays;
thunks.push_back(BuildKernelThunkForMlir(name, mlir_input.thunk_info,
operand_slices, &ir_arrays));
CHECK_EQ(ir_arrays.size(), 3);
const IrArray& scatter_indices = ir_arrays[0];
const IrArray& updates = ir_arrays[1];
const IrArray& output = ir_arrays[2];
auto get_index_type = [&](int64 launch_size) {
return GetIndexTypeForKernelFromMlir(scatter_op, launch_size, &b_);
};
TF_RETURN_IF_ERROR(EmitScatter(
thunks.back().get(), scatter,
thunks.back().get(), scatter_op, output,
/*scatter_indices_gen=*/
[=](const IrArray::Index& index) {
return GetIrArray(*scatter_indices, *scatter)
.EmitReadArrayElement(index, &b_, "scatter_index");
[&](const IrArray::Index& index) {
return scatter_indices.EmitReadArrayElement(index, &b_,
"scatter_index");
},
/*updates_gen=*/
[=](const IrArray::Index& index) {
return GetIrArray(*updates, *scatter)
.EmitReadArrayElement(index, &b_, "update");
}));
[&](const IrArray::Index& index) {
return updates.EmitReadArrayElement(index, &b_, "update");
},
/* get_index_type=*/
get_index_type));
// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(scatter), std::move(thunks)));
mlir_input.thunk_info, std::move(thunks)));
}
return Status::OK();
}
Status IrEmitterUnnested::EmitScatter(
Thunk* thunk, HloInstruction* scatter,
Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
const llvm_ir::IrArray& output,
const llvm_ir::ElementGenerator& scatter_indices_gen,
const llvm_ir::ElementGenerator& updates_gen) {
const HloInstruction* operand = scatter->operand(0);
const HloInstruction* scatter_indices = scatter->operand(1);
const HloInstruction* updates = scatter->operand(2);
const ScatterDimensionNumbers& dim_numbers =
scatter->scatter_dimension_numbers();
CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
const llvm_ir::ElementGenerator& updates_gen,
std::function<llvm::Type*(int64)> get_index_type) {
const Shape operand_shape = TypeToShape(scatter.operand().getType());
CHECK(
ShapeUtil::Equal(TypeToShape(scatter.output().getType()), operand_shape));
TF_ASSIGN_OR_RETURN(
const HloComputation* update_computation,
GetOrCreateSubComputationFromRegion(&scatter.update_computation()));
ScatterDescriptor desc;
desc.name = mlir::GetNameFromLoc(scatter.getLoc());
desc.operand_shape = operand_shape;
desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType());
desc.updates_shape = TypeToShape(scatter.updates().getType());
desc.dim_numbers = scatter.scatter_dimension_numbers();
desc.unique_indices = scatter.unique_indices();
desc.update_computation = update_computation;
desc.output = output;
desc.scatter_indices_gen = scatter_indices_gen;
desc.updates_gen = updates_gen;
desc.get_index_type = get_index_type;
return EmitScatter(desc, thunk);
}
Status IrEmitterUnnested::EmitScatter(const ScatterDescriptor& desc,
Thunk* thunk) {
auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
std::vector<llvm::Value*> raw_window_multidim;
std::vector<llvm::Value*> input_scatter_multidim;
@@ -1297,22 +1381,25 @@ Status IrEmitterUnnested::EmitScatter(
for (int64 i = 0, e = index.size(); i != e; ++i) {
// For window indices also remember the window size, this comes in handy
// later.
if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
if (BinarySearchDenseElementsAttr(desc.dim_numbers.update_window_dims(),
i)) {
raw_window_multidim.push_back(index[i]);
raw_window_bounds.push_back(updates->shape().dimensions(i));
raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
} else {
input_scatter_multidim.push_back(index[i]);
}
}
DCHECK_EQ(raw_window_multidim.size(),
dim_numbers.update_window_dims_size());
desc.dim_numbers.update_window_dims().size());
// Apply inserted_window_dims to the window dimensions.
int64 raw_window_multidim_idx = 0;
std::vector<llvm::Value*> input_window_multidim;
std::vector<int64> input_window_bounds;
for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) {
if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
for (int64 i = 0, e = desc.operand_shape.rank(); i != e; ++i) {
if (BinarySearchDenseElementsAttr(desc.dim_numbers.inserted_window_dims(),
i)) {
input_window_bounds.push_back(1); // Trivial dimension.
input_window_multidim.push_back(index.GetConstantWithIndexType(0));
} else {
@@ -1323,14 +1410,15 @@ Status IrEmitterUnnested::EmitScatter(
++raw_window_multidim_idx;
}
}
DCHECK_EQ(input_window_multidim.size(), operand->shape().rank());
DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());
// Insert a 1 dimension at the end if index_vector_dim requests one.
Shape scatter_indices_shape = scatter_indices->shape();
if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) {
scatter_indices_shape.add_dimensions(1);
scatter_indices_shape.mutable_layout()->add_minor_to_major(
dim_numbers.index_vector_dim());
Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
if (desc.dim_numbers.index_vector_dim().getInt() ==
desc.scatter_indices_shape.rank()) {
scatter_indices_shape_fixed.add_dimensions(1);
scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
desc.dim_numbers.index_vector_dim().getInt());
}
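// Worked example (illustrative): scatter_indices of shape s32[2] with
// index_vector_dim == 1 (equal to its rank) is treated as s32[2,1], so each
// scatter index is a length-1 index vector; SourceIndexOfReshape below maps
// loads back to the original shape.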
// Now load the indices corresponding to the current window from
@@ -1338,23 +1426,27 @@ Status IrEmitterUnnested::EmitScatter(
std::vector<llvm::Value*> raw_scatter_index_multidim =
input_scatter_multidim;
raw_scatter_index_multidim.insert(
raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(),
raw_scatter_index_multidim.begin() +
desc.dim_numbers.index_vector_dim().getInt(),
nullptr);
llvm::Value* is_in_bounds = b_.getTrue();
for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
for (int64 i = 0,
e = desc.dim_numbers.scatter_dims_to_operand_dims().size();
i != e; ++i) {
// Our index is stored along index_vector_dim, insert that into the lookup
// index into scatter_indices.
raw_scatter_index_multidim[dim_numbers.index_vector_dim()] =
raw_scatter_index_multidim[desc.dim_numbers.index_vector_dim().getInt()] =
index.GetConstantWithIndexType(i);
llvm_ir::IrArray::Index raw_scatter_index_index(
raw_scatter_index_multidim, scatter_indices_shape, index.GetType());
raw_scatter_index_multidim, scatter_indices_shape_fixed,
index.GetType());
int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
int64 operand_dim =
desc.dim_numbers.scatter_dims_to_operand_dims().getValue<int64>(i);
TF_ASSIGN_OR_RETURN(
llvm::Value* const loaded_scatter_index,
scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
scatter_indices_shape, scatter_indices->shape(), &b_)));
desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
// And add the index to our window index. This yields the output index.
llvm::Value* casted_scatter_index =
IntCast(loaded_scatter_index, index.GetType(),
@@ -1364,7 +1456,7 @@ Status IrEmitterUnnested::EmitScatter(
input_window_multidim[operand_dim] = dim_offset;
// Also do the bounds check now.
int64 max_index = operand->shape().dimensions(operand_dim) -
int64 max_index = desc.operand_shape.dimensions(operand_dim) -
input_window_bounds[operand_dim] + 1;
// is_in_bounds = index >= 0 && index < dim_size-window_size+1
// --> index u< dim_size-window_size+1
@@ -1378,25 +1470,23 @@ Status IrEmitterUnnested::EmitScatter(
llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
// All done, now just read from the calculated input from the window, and do
// an atomic store to the calculated location in the output.
HloInstruction* output_hlo =
scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter;
llvm_ir::IrArray::Index input_window_index(
input_window_multidim, output_hlo->shape(), index.GetType());
input_window_multidim, desc.output.GetShape(), index.GetType());
llvm::Value* output_address =
GetIrArray(*output_hlo, *output_hlo)
.EmitArrayElementAddress(input_window_index, &b_);
desc.output.EmitArrayElementAddress(input_window_index, &b_);
llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(updates->shape().element_type(),
llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
module_),
"input_address", &b_);
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
desc.updates_gen(index));
Store(input_ir_value, input_address);
if (!scatter->unique_indices()) {
if (!desc.unique_indices) {
return EmitAtomicOperationForNestedComputation(
*scatter->to_apply(), output_address, input_address);
*desc.update_computation, output_address, input_address);
} else {
return EmitCallToNestedComputation(*scatter->to_apply(),
return EmitCallToNestedComputation(*desc.update_computation,
{output_address, input_address},
output_address);
}
@@ -1406,15 +1496,14 @@ Status IrEmitterUnnested::EmitScatter(
// also do one kernel per window instead if bounds checks turn out to be a
// bottleneck.
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
updates->shape(), ir_emitter_context_->gpu_device_info());
desc.updates_shape, ir_emitter_context_->gpu_device_info());
UpdateLaunchDimensions(launch_dimensions, thunk,
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
launch_dimensions, &b_)
.EmitLoop(IrName(scatter),
GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
&b_));
.EmitLoop(desc.name,
desc.get_index_type(launch_dimensions.launch_bound()));
}
Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {


@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
namespace xla {
@@ -158,6 +159,7 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleRng(HloInstruction* random) override;
Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
Status HandleScatter(HloInstruction* scatter) override;
Status EmitScatterFromMlir(MlirEmitterInput mlir_input);
Status HandleSelect(HloInstruction* select) override;
Status HandleSort(HloInstruction* sort) override;
Status EmitSortFromMlir(MlirEmitterInput mlir_input);
@@ -407,16 +409,38 @@ class IrEmitterUnnested : public IrEmitter,
const llvm_ir::IrArray::Index& slice_input_index);
// Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
// the process. `scatter` may be fused, scatter indices are taken from
// `scatter_indices_gen`, updates from`updates_gen`. The output buffer is
// expected to have the operand values in it already. If unique_indices
// is false, we will use an atomic update. Using true for unique_indices
// behaves properly only when it is guaranteed that the indices to be
// updated do not overlap. The caller is responsible for ensuring this is
// the case.
Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
// the process. Scatter indices are taken from `scatter_indices_gen`, updates
// from `updates_gen`. The output buffer is expected to have the operand
// values in it already. If unique_indices is false, we will use an atomic
// update. Using true for unique_indices behaves properly only when it is
// guaranteed that the indices to be updated do not overlap. The caller is
// responsible for ensuring this is the case.
Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
const llvm_ir::IrArray& output,
const llvm_ir::ElementGenerator& scatter_indices_gen,
const llvm_ir::ElementGenerator& updates_gen);
const llvm_ir::ElementGenerator& updates_gen,
std::function<llvm::Type*(int64)> get_index_type);
// Structure describing a scatter operation for IR emission.
// TODO(jurahul): Migrate element generators to use MLIR.
// Migrate update_computation to be an MLIR Region.
struct ScatterDescriptor {
std::string name;
Shape operand_shape;
Shape scatter_indices_shape;
Shape updates_shape;
mlir::mhlo::ScatterDimensionNumbers dim_numbers;
bool unique_indices;
const HloComputation* update_computation;
llvm_ir::IrArray output;
llvm_ir::ElementGenerator scatter_indices_gen;
llvm_ir::ElementGenerator updates_gen;
std::function<llvm::Type*(int64)> get_index_type;
};
// Emits code for an in-place scatter using the provided scatter operation
// description.
Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk);
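
The descriptor lets the fused HLO path and the MLIR path share one emitter; only the index-type callback differs between the two producers. A condensed sketch of the two call sites (from ir_emitter_unnested.cc above, error handling elided):

    // Fused (HLO) path:
    desc.get_index_type = [&](int64 launch_size) {
      return GetIndexTypeForKernel(root, launch_size, &b_);
    };
    // MLIR path:
    desc.get_index_type = [&](int64 launch_size) {
      return GetIndexTypeForKernelFromMlir(scatter_op, launch_size, &b_);
    };
    TF_RETURN_IF_ERROR(EmitScatter(desc, thunk));
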
// Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
// for the hlo instruction.


@@ -225,7 +225,7 @@ void EmitBitcodeToFile(const llvm::Module& module, absl::string_view filename) {
// for the NVPTX target.
string EmitModuleToPTX(llvm::Module* module,
llvm::TargetMachine* target_machine) {
std::string ptx; // need a std::string instead of a ::string.
std::string ptx;
{
llvm::raw_string_ostream stream(ptx);
llvm::buffer_ostream pstream(stream);


@@ -3,14 +3,12 @@
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
// CHECK: entry:
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [2 x i32]*
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [2 x [3 x i32]]*
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_12:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_13:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_14:.*]] = mul nuw nsw i32 %[[VAL_12]], 6
@@ -75,14 +73,12 @@ ENTRY main {
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) {
// CHECK: entry:
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
// CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32*
// CHECK: %[[VAL_40:.*]] = getelementptr inbounds i8, i8* %[[VAL_41:.*]], i64 0
// CHECK: %[[VAL_42:.*]] = bitcast i8* %[[VAL_40]] to i32*
// CHECK: %[[VAL_43:.*]] = getelementptr inbounds i8, i8* %[[VAL_44:.*]], i64 0
// CHECK: %[[VAL_45:.*]] = bitcast i8* %[[VAL_43]] to [0 x i32]*
// CHECK: %[[VAL_46:.*]] = getelementptr inbounds i8, i8* %[[VAL_47:.*]], i64 0
// CHECK: %[[VAL_48:.*]] = bitcast i8* %[[VAL_46]] to i32*
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
// CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32*
// CHECK: %[[VAL_49:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_50:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
// CHECK: %[[VAL_51:.*]] = mul nuw nsw i32 %[[VAL_49]], 1
@@ -135,14 +131,12 @@ ENTRY main {
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
// CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0
// CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_68:.*]] = getelementptr inbounds i8, i8* %[[VAL_69:.*]], i64 0
// CHECK: %[[VAL_70:.*]] = bitcast i8* %[[VAL_68]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_71:.*]] = getelementptr inbounds i8, i8* %[[VAL_72:.*]], i64 0
// CHECK: %[[VAL_73:.*]] = bitcast i8* %[[VAL_71]] to [2 x i32]*
// CHECK: %[[VAL_74:.*]] = getelementptr inbounds i8, i8* %[[VAL_75:.*]], i64 0
// CHECK: %[[VAL_76:.*]] = bitcast i8* %[[VAL_74]] to [2 x [3 x i32]]*
// CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0
// CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]*
// CHECK: %[[VAL_77:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_78:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_79:.*]] = mul nuw nsw i32 %[[VAL_77]], 6
@@ -180,7 +174,7 @@ ENTRY main {
// CHECK: atomic_op_loop_body: ; preds = %[[VAL_104]], %[[VAL_95]]
// CHECK: %[[VAL_105:.*]] = load i32, i32* %[[VAL_64]], align 4
// CHECK: store i32 %[[VAL_105]], i32* %[[VAL_63]], align 4
// CHECK: call void @mul_s32(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]])
// CHECK: call void @{{.+}}(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]])
// CHECK: %[[VAL_106:.*]] = load i32, i32* %[[VAL_63]], align 4
// CHECK: %[[VAL_107:.*]] = cmpxchg i32* %[[VAL_97]], i32 %[[VAL_105]], i32 %[[VAL_106]] seq_cst seq_cst
// CHECK: %[[VAL_108:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 0
@@ -219,14 +213,12 @@ ENTRY main {
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) {
// CHECK: entry:
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
// CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]*
// CHECK: %[[VAL_121:.*]] = getelementptr inbounds i8, i8* %[[VAL_122:.*]], i64 0
// CHECK: %[[VAL_123:.*]] = bitcast i8* %[[VAL_121]] to [4 x i32]*
// CHECK: %[[VAL_124:.*]] = getelementptr inbounds i8, i8* %[[VAL_125:.*]], i64 0
// CHECK: %[[VAL_126:.*]] = bitcast i8* %[[VAL_124]] to i32*
// CHECK: %[[VAL_127:.*]] = getelementptr inbounds i8, i8* %[[VAL_128:.*]], i64 0
// CHECK: %[[VAL_129:.*]] = bitcast i8* %[[VAL_127]] to i32*
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
// CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]*
// CHECK: %[[VAL_130:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_131:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
// CHECK: %[[VAL_132:.*]] = mul nuw nsw i32 %[[VAL_130]], 1