Change scatter emission to go through MLIR
- For non-fused scatter, convert it to an LHLO ScatterOp and use that to drive emission.
- For fused scatter, extract the relevant pieces and use those to drive scatter code emission.

PiperOrigin-RevId: 338266178
Change-Id: I659d1f4aa7e13048a37a1e3f85e869bf25e84a0f
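For orientation, here is a minimal CPU model of what the emitted scatter kernel computes, for the common 2-D case exercised by the scatter_TensorFlowScatterV1 test below (update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1). This is an illustrative sketch only; the commit emits this logic as GPU LLVM IR.

#include <cstdint>
#include <vector>

// `output` starts as a copy of the operand; each update row is written to the
// operand row selected by the corresponding scatter index. The update
// computation here is "assign"; the emitted kernel instead calls the scatter
// op's region (atomically unless unique_indices is true).
void ScatterRowsModel(std::vector<std::vector<int32_t>>& output,
                      const std::vector<int32_t>& scatter_indices,
                      const std::vector<std::vector<int32_t>>& updates) {
  for (size_t u = 0; u < updates.size(); ++u) {
    const int64_t row = scatter_indices[u];
    // Out-of-bounds rows are skipped, mirroring the bounds check the emitted
    // kernel performs before storing.
    if (row < 0 || row >= static_cast<int64_t>(output.size())) continue;
    for (size_t c = 0; c < updates[u].size(); ++c) {
      output[row][c] = updates[u][c];
    }
  }
}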
Parent: fec830a3c8
Commit: 0d03786e46
@@ -134,8 +134,8 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
   return Status::OK();
 }
 
-// This pass take a MLIR HLO module, convert it to XLA to perform the HLO
-// optimization pipeline for the required platform, and then convert back to
+// This pass takes an MLIR HLO module, converts it to XLA to perform the HLO
+// optimization pipeline for the required platform, and then converts it back to
 // MLIR LHLO.
 class XlaHloToLhloPass
     : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
@@ -400,6 +400,49 @@ Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) {
   return EmitFusionOp(instr).status();
 }
 
+StatusOr<mhlo::ScatterDimensionNumbers>
+LhloDialectEmitter::GetScatterDimensionNumbers(HloInstruction* instr) {
+  auto* scatter_instr = ::xla::Cast<::xla::HloScatterInstruction>(instr);
+
+  const ::xla::ScatterDimensionNumbers& xla_scatter_dim =
+      scatter_instr->scatter_dimension_numbers();
+  auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get(
+      getI64DenseElementsAttr(xla_scatter_dim.update_window_dims()),
+      getI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()),
+      getI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()),
+      builder_.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()),
+      module_.getContext());
+  return scatter_dimension_numbers;
+}
+
+StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
+    HloInstruction* instr) {
+  TF_ASSIGN_OR_RETURN(auto scatter,
+                      CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));
+
+  // Copy the attributes.
+  auto* scatter_instr = ::xla::Cast<::xla::HloScatterInstruction>(instr);
+
+  TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers,
+                      GetScatterDimensionNumbers(instr));
+  scatter.scatter_dimension_numbersAttr(scatter_dimension_numbers);
+  scatter.indices_are_sortedAttr(
+      builder_.getBoolAttr(scatter_instr->indices_are_sorted()));
+  scatter.unique_indicesAttr(
+      builder_.getBoolAttr(scatter_instr->unique_indices()));
+
+  // Import the update computation as a region.
+  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
+      *scatter_instr->called_computations()[0], &scatter.update_computation(),
+      &builder_));
+
+  return scatter;
+}
+
+Status LhloDialectEmitter::HandleScatter(HloInstruction* instr) {
+  return EmitScatterOp(instr).status();
+}
+
 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
     const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
     const ::xla::ShapeIndex& shape_index) {
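As a concrete reference for the attribute values this builds, here is a hypothetical sketch for the scatter_TensorFlowScatterV1 case from the tests below (operand [3,3], indices [2], updates [2,3]). MakeExampleDimNumbers is an illustrative helper; the builder calls mirror the diff but are not verified against this exact revision of the MLIR API.

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

// Hypothetical: dimension numbers GetScatterDimensionNumbers would produce
// for a row-wise scatter (update_window_dims=[1], inserted_window_dims=[0],
// scatter_dims_to_operand_dims=[0], index_vector_dim=1).
mlir::mhlo::ScatterDimensionNumbers MakeExampleDimNumbers(
    mlir::MLIRContext* context) {
  mlir::Builder b(context);
  return mlir::mhlo::ScatterDimensionNumbers::get(
      b.getI64TensorAttr({1}),  // update_window_dims
      b.getI64TensorAttr({0}),  // inserted_window_dims
      b.getI64TensorAttr({0}),  // scatter_dims_to_operand_dims
      b.getI64IntegerAttr(1),   // index_vector_dim
      context);
}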
@@ -16,11 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
 #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
 
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 
 namespace mlir {
@@ -44,11 +46,20 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
 
   ::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);
   ::xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(::xla::HloInstruction* instr);
+  ::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr);
+  ::xla::StatusOr<mhlo::ScatterDimensionNumbers> GetScatterDimensionNumbers(
+      ::xla::HloInstruction* instr);
 
  private:
   template <typename OpType>
   ::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr);
 
+  template <typename T>
+  DenseIntElementsAttr getI64DenseElementsAttr(const T& container) {
+    return builder_.getI64TensorAttr(
+        {container.data(), static_cast<size_t>(container.size())});
+  }
+
   tensorflow::Status DefaultAction(::xla::HloInstruction* instr) final;
 
   // Computation parameters don't need any specific handling when they are
@@ -59,6 +70,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
 
   tensorflow::Status HandleSort(::xla::HloInstruction* instr) final;
   tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final;
+  tensorflow::Status HandleScatter(::xla::HloInstruction* instr) final;
 
   // Helper function that recursively visits the tuple structure in
   // `current_shape`, and reconstructs a matching lmhlo::TupleOp.
@@ -38,6 +38,7 @@ limitations under the License.
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
@@ -203,7 +204,7 @@ StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
 }
 
 StatusOr<std::vector<MlirBufferSlice>> GetMlirBufferSlices(
-    mlir::Operation* op, mlir::OperandRange operands,
+    mlir::Operation* op, mlir::ValueRange operands,
     absl::Span<const BufferAllocation> allocations) {
   const auto buffer_is_written = [op](mlir::Value operand) {
     llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
@@ -227,6 +228,14 @@ StatusOr<std::vector<MlirBufferSlice>> GetMlirBufferSlices(
   return slices;
 }
 
+bool BinarySearchDenseElementsAttr(::mlir::DenseIntElementsAttr elements,
+                                   int64 v) {
+  ::mlir::APInt value(sizeof(int64) * 8, v, /*isSigned=*/true);
+  return std::binary_search(
+      elements.begin(), elements.end(), value,
+      [](const ::mlir::APInt& x, const ::mlir::APInt& y) { return x.slt(y); });
+}
+
 }  // namespace
 
 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
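DenseIntElementsAttr iterates as mlir::APInt (an alias of llvm::APInt), which defines no operator<, so std::binary_search needs the explicit signed-less-than predicate seen above. A standalone model of the same pattern, assuming only LLVM's ADT headers:

#include <algorithm>
#include <cstdint>
#include <vector>

#include "llvm/ADT/APInt.h"

// Standalone model of BinarySearchDenseElementsAttr: membership test over a
// sorted sequence of APInt values, ordered by signed less-than (slt).
bool ContainsSigned(const std::vector<llvm::APInt>& sorted_elements,
                    int64_t v) {
  llvm::APInt value(/*numBits=*/64, v, /*isSigned=*/true);
  return std::binary_search(
      sorted_elements.begin(), sorted_elements.end(), value,
      [](const llvm::APInt& x, const llvm::APInt& y) { return x.slt(y); });
}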
@@ -444,7 +453,7 @@ llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op,
   }
 
   // Check the size of the internal result tensors
-  if (auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op)) {
+  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
     auto result = fusion.region().walk([&](mlir::Operation* op) {
       for (mlir::Value result : op->getResults()) {
         if (!hlo_shape_in_range(result)) {
@@ -800,12 +809,30 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
                                        GetGeneratorForOperandIrArrays(fusion),
                                        &scatter_elemental_emitter);
     TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
-    TF_RETURN_IF_ERROR(EmitScatter(
-        thunks.back().get(), root,
-        /*scatter_indices_gen=*/
-        scatter_fused_emitter.GetGenerator(root->operand(1)),
-        /*updates_gen=*/
-        scatter_fused_emitter.GetGenerator(root->operand(2))));
+    CHECK_EQ(root->parent()->FusionInstruction(), fusion);
+
+    TF_ASSIGN_OR_RETURN(
+        const auto dim_numbers,
+        lhlo_scratch_emitter_.GetScatterDimensionNumbers(root));
+
+    ScatterDescriptor desc;
+    desc.name = IrName(root);
+    desc.operand_shape = root->operand(0)->shape();
+    desc.scatter_indices_shape = root->operand(1)->shape();
+    desc.updates_shape = root->operand(2)->shape();
+    desc.dim_numbers = dim_numbers;
+    desc.unique_indices = root->unique_indices();
+    desc.update_computation = root->called_computations()[0];
+    desc.output = GetIrArray(*fusion, *fusion);
+    desc.scatter_indices_gen =
+        scatter_fused_emitter.GetGenerator(root->operand(1));
+    desc.updates_gen =
+        scatter_fused_emitter.GetGenerator(root->operand(2));
+    desc.get_index_type = [&](int64 launch_size) {
+      return GetIndexTypeForKernel(root, launch_size, &b_);
+    };
+
+    TF_RETURN_IF_ERROR(EmitScatter(desc, thunks.back().get()));
   }
   AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
       GetThunkInfo(fusion), std::move(thunks)));
@@ -1233,61 +1260,118 @@ Status IrEmitterUnnested::HandleRngGetAndUpdateState(
 }
 
 Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
-  const HloInstruction* operand = scatter->operand(0);
-  const HloInstruction* scatter_indices = scatter->operand(1);
-  const HloInstruction* updates = scatter->operand(2);
+  MlirEmitterInput result;
+
+  TF_ASSIGN_OR_RETURN(auto scatter_op,
+                      lhlo_scratch_emitter_.EmitScatterOp(scatter));
+  result.op = scatter_op;
+  result.thunk_info = GetThunkInfo(scatter);
+  return EmitScatterFromMlir(result);
+}
+
+Status IrEmitterUnnested::EmitScatterFromMlir(MlirEmitterInput mlir_input) {
   std::vector<std::unique_ptr<Thunk>> thunks;
 
+  absl::Span<const BufferAllocation> allocations(
+      ir_emitter_context_->buffer_assignment().Allocations());
+
+  ::mlir::lmhlo::ScatterOp scatter_op =
+      ::mlir::cast<::mlir::lmhlo::ScatterOp>(mlir_input.op);
+
+  TF_ASSIGN_OR_RETURN(
+      auto operand_buffer,
+      GetAllocationSliceForMlir(scatter_op.operand(), allocations));
+  TF_ASSIGN_OR_RETURN(
+      auto output_buffer,
+      GetAllocationSliceForMlir(scatter_op.output(), allocations));
+
   // Copy the operand into the output if it's not the same buffer already.
-  auto operand_buffer = GetAllocationSlice(*operand);
-  auto destination_buffer = GetAllocationSlice(*scatter);
-  if (operand_buffer != destination_buffer) {
+  if (operand_buffer != output_buffer) {
     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
         Thunk::ThunkInfo(),
         /*source_address=*/operand_buffer,
-        /*destination_buffer=*/destination_buffer,
-        /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape())));
+        /*destination_buffer=*/output_buffer,
+        /*mem_size=*/
+        ShapeUtil::ByteSizeOf(TypeToShape(scatter_op.output().getType()))));
   }
 
-  thunks.push_back(
-      BuildKernelThunk(scatter,
-                       /*implements_whole_instruction=*/thunks.empty()));
+  // Create MLIR buffer slice info for all operands except the first one
+  // (`operand`). The code generated for scatter below assumes that the input
+  // operand is already copied into the output, so does not use it in codegen.
+  TF_ASSIGN_OR_RETURN(
+      std::vector<MlirBufferSlice> operand_slices,
+      GetMlirBufferSlices(scatter_op, scatter_op.getOperands().drop_front(),
+                          allocations));
+
+  std::string name = mlir::GetNameFromLoc(scatter_op.getLoc());
+  std::vector<llvm_ir::IrArray> ir_arrays;
+  thunks.push_back(BuildKernelThunkForMlir(name, mlir_input.thunk_info,
+                                           operand_slices, &ir_arrays));
+  CHECK_EQ(ir_arrays.size(), 3);
+  const IrArray& scatter_indices = ir_arrays[0];
+  const IrArray& updates = ir_arrays[1];
+  const IrArray& output = ir_arrays[2];
+
+  auto get_index_type = [&](int64 launch_size) {
+    return GetIndexTypeForKernelFromMlir(scatter_op, launch_size, &b_);
+  };
 
   TF_RETURN_IF_ERROR(EmitScatter(
-      thunks.back().get(), scatter,
+      thunks.back().get(), scatter_op, output,
       /*scatter_indices_gen=*/
-      [=](const IrArray::Index& index) {
-        return GetIrArray(*scatter_indices, *scatter)
-            .EmitReadArrayElement(index, &b_, "scatter_index");
+      [&](const IrArray::Index& index) {
+        return scatter_indices.EmitReadArrayElement(index, &b_,
+                                                    "scatter_index");
      },
      /*updates_gen=*/
-      [=](const IrArray::Index& index) {
-        return GetIrArray(*updates, *scatter)
-            .EmitReadArrayElement(index, &b_, "update");
-      }));
+      [&](const IrArray::Index& index) {
+        return updates.EmitReadArrayElement(index, &b_, "update");
+      },
+      /*get_index_type=*/
+      get_index_type));
 
   // Elide the sequential thunk if there's no copy.
   if (thunks.size() == 1) {
     AddThunkToThunkSequence(std::move(thunks[0]));
   } else {
     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
-        GetThunkInfo(scatter), std::move(thunks)));
+        mlir_input.thunk_info, std::move(thunks)));
   }
 
   return Status::OK();
 }
 
 Status IrEmitterUnnested::EmitScatter(
-    Thunk* thunk, HloInstruction* scatter,
+    Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
+    const llvm_ir::IrArray& output,
     const llvm_ir::ElementGenerator& scatter_indices_gen,
-    const llvm_ir::ElementGenerator& updates_gen) {
-  const HloInstruction* operand = scatter->operand(0);
-  const HloInstruction* scatter_indices = scatter->operand(1);
-  const HloInstruction* updates = scatter->operand(2);
-  const ScatterDimensionNumbers& dim_numbers =
-      scatter->scatter_dimension_numbers();
-  CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
+    const llvm_ir::ElementGenerator& updates_gen,
+    std::function<llvm::Type*(int64)> get_index_type) {
+  const Shape operand_shape = TypeToShape(scatter.operand().getType());
+  CHECK(
+      ShapeUtil::Equal(TypeToShape(scatter.output().getType()), operand_shape));
+
+  TF_ASSIGN_OR_RETURN(
+      const HloComputation* update_computation,
+      GetOrCreateSubComputationFromRegion(&scatter.update_computation()));
+
+  ScatterDescriptor desc;
+  desc.name = mlir::GetNameFromLoc(scatter.getLoc());
+  desc.operand_shape = operand_shape;
+  desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType());
+  desc.updates_shape = TypeToShape(scatter.updates().getType());
+  desc.dim_numbers = scatter.scatter_dimension_numbers();
+  desc.unique_indices = scatter.unique_indices();
+  desc.update_computation = update_computation;
+  desc.output = output;
+  desc.scatter_indices_gen = scatter_indices_gen;
+  desc.updates_gen = updates_gen;
+  desc.get_index_type = get_index_type;
+  return EmitScatter(desc, thunk);
+}
+
+Status IrEmitterUnnested::EmitScatter(const ScatterDescriptor& desc,
+                                      Thunk* thunk) {
   auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
     std::vector<llvm::Value*> raw_window_multidim;
     std::vector<llvm::Value*> input_scatter_multidim;
@@ -1297,22 +1381,25 @@ Status IrEmitterUnnested::EmitScatter(
     for (int64 i = 0, e = index.size(); i != e; ++i) {
       // For window indices also remember the window size, this comes in handy
      // later.
-      if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
+      if (BinarySearchDenseElementsAttr(desc.dim_numbers.update_window_dims(),
+                                        i)) {
        raw_window_multidim.push_back(index[i]);
-        raw_window_bounds.push_back(updates->shape().dimensions(i));
+        raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
      } else {
        input_scatter_multidim.push_back(index[i]);
      }
    }
    DCHECK_EQ(raw_window_multidim.size(),
-              dim_numbers.update_window_dims_size());
+              desc.dim_numbers.update_window_dims().size());
 
    // Apply inserted_window_dims to the window dimensions.
    int64 raw_window_multidim_idx = 0;
    std::vector<llvm::Value*> input_window_multidim;
    std::vector<int64> input_window_bounds;
-    for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) {
-      if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
+
+    for (int64 i = 0, e = desc.operand_shape.rank(); i != e; ++i) {
+      if (BinarySearchDenseElementsAttr(desc.dim_numbers.inserted_window_dims(),
+                                        i)) {
        input_window_bounds.push_back(1);  // Trivial dimension.
        input_window_multidim.push_back(index.GetConstantWithIndexType(0));
      } else {
@@ -1323,14 +1410,15 @@ Status IrEmitterUnnested::EmitScatter(
        ++raw_window_multidim_idx;
      }
    }
-    DCHECK_EQ(input_window_multidim.size(), operand->shape().rank());
+    DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());
 
    // Insert a 1 dimension at the end if index_vector_dim requests one.
-    Shape scatter_indices_shape = scatter_indices->shape();
-    if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) {
-      scatter_indices_shape.add_dimensions(1);
-      scatter_indices_shape.mutable_layout()->add_minor_to_major(
-          dim_numbers.index_vector_dim());
+    Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
+    if (desc.dim_numbers.index_vector_dim().getInt() ==
+        desc.scatter_indices_shape.rank()) {
+      scatter_indices_shape_fixed.add_dimensions(1);
+      scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
+          desc.dim_numbers.index_vector_dim().getInt());
    }
 
    // Now load the indices corresponding to the current window from
@@ -1338,23 +1426,27 @@ Status IrEmitterUnnested::EmitScatter(
    std::vector<llvm::Value*> raw_scatter_index_multidim =
        input_scatter_multidim;
    raw_scatter_index_multidim.insert(
-        raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(),
+        raw_scatter_index_multidim.begin() +
+            desc.dim_numbers.index_vector_dim().getInt(),
        nullptr);
    llvm::Value* is_in_bounds = b_.getTrue();
-    for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
+    for (int64 i = 0,
+               e = desc.dim_numbers.scatter_dims_to_operand_dims().size();
         i != e; ++i) {
      // Our index is stored along index_vector_dim, insert that into the lookup
      // index into scatter_indices.
-      raw_scatter_index_multidim[dim_numbers.index_vector_dim()] =
+      raw_scatter_index_multidim[desc.dim_numbers.index_vector_dim().getInt()] =
          index.GetConstantWithIndexType(i);
      llvm_ir::IrArray::Index raw_scatter_index_index(
-          raw_scatter_index_multidim, scatter_indices_shape, index.GetType());
+          raw_scatter_index_multidim, scatter_indices_shape_fixed,
+          index.GetType());
 
-      int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
+      int64 operand_dim =
+          desc.dim_numbers.scatter_dims_to_operand_dims().getValue<int64>(i);
      TF_ASSIGN_OR_RETURN(
          llvm::Value* const loaded_scatter_index,
-          scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
-              scatter_indices_shape, scatter_indices->shape(), &b_)));
+          desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
+              scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
      // And add the index to our window index. This yields the output index.
      llvm::Value* casted_scatter_index =
          IntCast(loaded_scatter_index, index.GetType(),
@@ -1364,7 +1456,7 @@ Status IrEmitterUnnested::EmitScatter(
      input_window_multidim[operand_dim] = dim_offset;
 
      // Also do the bounds check now.
-      int64 max_index = operand->shape().dimensions(operand_dim) -
+      int64 max_index = desc.operand_shape.dimensions(operand_dim) -
                        input_window_bounds[operand_dim] + 1;
      // is_in_bounds = index >= 0 && index < dim_size-window_size+1
      //   --> index u< dim_size-window_size+1
@@ -1378,25 +1470,23 @@ Status IrEmitterUnnested::EmitScatter(
    llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
    // All done, now just read from the calculated input from the window, and do
    // an atomic store to the calculated location in the output.
-    HloInstruction* output_hlo =
-        scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter;
    llvm_ir::IrArray::Index input_window_index(
-        input_window_multidim, output_hlo->shape(), index.GetType());
+        input_window_multidim, desc.output.GetShape(), index.GetType());
    llvm::Value* output_address =
-        GetIrArray(*output_hlo, *output_hlo)
-            .EmitArrayElementAddress(input_window_index, &b_);
+        desc.output.EmitArrayElementAddress(input_window_index, &b_);
    llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
-        llvm_ir::PrimitiveTypeToIrType(updates->shape().element_type(),
+        llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
                                       module_),
        "input_address", &b_);
-    TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
+    TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+                        desc.updates_gen(index));
    Store(input_ir_value, input_address);
 
-    if (!scatter->unique_indices()) {
+    if (!desc.unique_indices) {
      return EmitAtomicOperationForNestedComputation(
-          *scatter->to_apply(), output_address, input_address);
+          *desc.update_computation, output_address, input_address);
    } else {
-      return EmitCallToNestedComputation(*scatter->to_apply(),
+      return EmitCallToNestedComputation(*desc.update_computation,
                                         {output_address, input_address},
                                         output_address);
    }
@@ -1406,15 +1496,14 @@ Status IrEmitterUnnested::EmitScatter(
  // also do one kernel per window instead if bounds checks turn out to be a
  // bottleneck.
  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      updates->shape(), ir_emitter_context_->gpu_device_info());
+      desc.updates_shape, ir_emitter_context_->gpu_device_info());
  UpdateLaunchDimensions(launch_dimensions, thunk,
                         ir_emitter_context_->llvm_module());
 
-  return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
+  return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
                             launch_dimensions, &b_)
-      .EmitLoop(IrName(scatter),
-                GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
-                                      &b_));
+      .EmitLoop(desc.name,
+                desc.get_index_type(launch_dimensions.launch_bound()));
 }
 
 Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
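To summarize the index bookkeeping that the loop body above emits as IR, here is a standalone CPU model. Shapes are illustrative, and the dimension-number sets are plain integers here rather than DenseIntElementsAttr; it is a sketch of the technique, not the emitter's code.

#include <cstdint>
#include <set>
#include <vector>

// Splits an index into `updates` into window coordinates (along
// update_window_dims) and scatter coordinates (the remaining dims), then
// rebuilds the operand-rank window index with inserted_window_dims
// re-expanded as trivial (constant-0) dimensions.
struct SplitIndex {
  std::vector<int64_t> window;        // coords along update_window_dims
  std::vector<int64_t> input_window;  // operand-rank window index
  std::vector<int64_t> scatter;       // coords along the remaining dims
};

SplitIndex SplitUpdateIndex(const std::vector<int64_t>& update_index,
                            const std::set<int64_t>& update_window_dims,
                            const std::set<int64_t>& inserted_window_dims,
                            int64_t operand_rank) {
  SplitIndex out;
  for (int64_t i = 0; i < static_cast<int64_t>(update_index.size()); ++i) {
    if (update_window_dims.count(i))
      out.window.push_back(update_index[i]);
    else
      out.scatter.push_back(update_index[i]);
  }
  // Re-expand: each inserted window dim contributes a constant 0 coordinate.
  int64_t w = 0;
  for (int64_t i = 0; i < operand_rank; ++i) {
    if (inserted_window_dims.count(i))
      out.input_window.push_back(0);
    else
      out.input_window.push_back(out.window[w++]);
  }
  return out;
}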
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
 
 namespace xla {
@@ -158,6 +159,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleRng(HloInstruction* random) override;
   Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
   Status HandleScatter(HloInstruction* scatter) override;
+  Status EmitScatterFromMlir(MlirEmitterInput mlir_input);
   Status HandleSelect(HloInstruction* select) override;
   Status HandleSort(HloInstruction* sort) override;
   Status EmitSortFromMlir(MlirEmitterInput mlir_input);
@@ -407,16 +409,38 @@ class IrEmitterUnnested : public IrEmitter,
       const llvm_ir::IrArray::Index& slice_input_index);
 
   // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
-  // the process. `scatter` may be fused, scatter indices are taken from
-  // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is
-  // expected to have the operand values in it already. If unique_indices
-  // is false, we will use an atomic update. Using true for unique_indices
-  // behaves properly only when it is guaranteed that the indices to be
-  // updated do not overlap. The caller is responsible for ensuring this is
-  // the case.
-  Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
+  // the process. Scatter indices are taken from `scatter_indices_gen`, updates
+  // from `updates_gen`. The output buffer is expected to have the operand
+  // values in it already. If unique_indices is false, we will use an atomic
+  // update. Using true for unique_indices behaves properly only when it is
+  // guaranteed that the indices to be updated do not overlap. The caller is
+  // responsible for ensuring this is the case.
+  Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
+                     const llvm_ir::IrArray& output,
                      const llvm_ir::ElementGenerator& scatter_indices_gen,
-                     const llvm_ir::ElementGenerator& updates_gen);
+                     const llvm_ir::ElementGenerator& updates_gen,
+                     std::function<llvm::Type*(int64)> get_index_type);
+
+  // Structure describing a scatter operation for IR emission.
+  // TODO(jurahul): Migrate element generators to use MLIR.
+  //                Migrate update_computation to be an MLIR Region.
+  struct ScatterDescriptor {
+    std::string name;
+    Shape operand_shape;
+    Shape scatter_indices_shape;
+    Shape updates_shape;
+    mlir::mhlo::ScatterDimensionNumbers dim_numbers;
+    bool unique_indices;
+    const HloComputation* update_computation;
+    llvm_ir::IrArray output;
+    llvm_ir::ElementGenerator scatter_indices_gen;
+    llvm_ir::ElementGenerator updates_gen;
+    std::function<llvm::Type*(int64)> get_index_type;
+  };
+
+  // Emits code for an in-place scatter using the provided scatter operation
+  // description.
+  Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk);
 
   // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
   // for the hlo instruction.
@@ -225,7 +225,7 @@ void EmitBitcodeToFile(const llvm::Module& module, absl::string_view filename) {
 // for the NVPTX target.
 string EmitModuleToPTX(llvm::Module* module,
                        llvm::TargetMachine* target_machine) {
-  std::string ptx;  // need a std::string instead of a ::string.
+  std::string ptx;
   {
     llvm::raw_string_ostream stream(ptx);
     llvm::buffer_ostream pstream(stream);
@@ -3,14 +3,12 @@
 // CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
 // CHECK: entry:
 // CHECK: %[[VAL_32:.*]] = alloca i32, align 4
-// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
-// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]*
-// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
-// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [3 x [3 x i32]]*
 // CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
 // CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [2 x i32]*
 // CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
 // CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [2 x [3 x i32]]*
+// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
+// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]*
 // CHECK: %[[VAL_12:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
 // CHECK: %[[VAL_13:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
 // CHECK: %[[VAL_14:.*]] = mul nuw nsw i32 %[[VAL_12]], 6
@@ -75,14 +73,12 @@ ENTRY main {
 // CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) {
 // CHECK: entry:
 // CHECK: %[[VAL_60:.*]] = alloca i32, align 4
-// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
-// CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32*
-// CHECK: %[[VAL_40:.*]] = getelementptr inbounds i8, i8* %[[VAL_41:.*]], i64 0
-// CHECK: %[[VAL_42:.*]] = bitcast i8* %[[VAL_40]] to i32*
 // CHECK: %[[VAL_43:.*]] = getelementptr inbounds i8, i8* %[[VAL_44:.*]], i64 0
 // CHECK: %[[VAL_45:.*]] = bitcast i8* %[[VAL_43]] to [0 x i32]*
 // CHECK: %[[VAL_46:.*]] = getelementptr inbounds i8, i8* %[[VAL_47:.*]], i64 0
 // CHECK: %[[VAL_48:.*]] = bitcast i8* %[[VAL_46]] to i32*
+// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
+// CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32*
 // CHECK: %[[VAL_49:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
 // CHECK: %[[VAL_50:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
 // CHECK: %[[VAL_51:.*]] = mul nuw nsw i32 %[[VAL_49]], 1
@@ -135,14 +131,12 @@ ENTRY main {
 // CHECK: %[[VAL_63:.*]] = alloca i32, align 4
 // CHECK: %[[VAL_64:.*]] = alloca i32, align 4
 // CHECK: %[[VAL_98:.*]] = alloca i32, align 4
-// CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0
-// CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]*
-// CHECK: %[[VAL_68:.*]] = getelementptr inbounds i8, i8* %[[VAL_69:.*]], i64 0
-// CHECK: %[[VAL_70:.*]] = bitcast i8* %[[VAL_68]] to [3 x [3 x i32]]*
 // CHECK: %[[VAL_71:.*]] = getelementptr inbounds i8, i8* %[[VAL_72:.*]], i64 0
 // CHECK: %[[VAL_73:.*]] = bitcast i8* %[[VAL_71]] to [2 x i32]*
 // CHECK: %[[VAL_74:.*]] = getelementptr inbounds i8, i8* %[[VAL_75:.*]], i64 0
 // CHECK: %[[VAL_76:.*]] = bitcast i8* %[[VAL_74]] to [2 x [3 x i32]]*
+// CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0
+// CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]*
 // CHECK: %[[VAL_77:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
 // CHECK: %[[VAL_78:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
 // CHECK: %[[VAL_79:.*]] = mul nuw nsw i32 %[[VAL_77]], 6
@@ -180,7 +174,7 @@ ENTRY main {
 // CHECK: atomic_op_loop_body: ; preds = %[[VAL_104]], %[[VAL_95]]
 // CHECK: %[[VAL_105:.*]] = load i32, i32* %[[VAL_64]], align 4
 // CHECK: store i32 %[[VAL_105]], i32* %[[VAL_63]], align 4
-// CHECK: call void @mul_s32(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]])
+// CHECK: call void @{{.+}}(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]])
 // CHECK: %[[VAL_106:.*]] = load i32, i32* %[[VAL_63]], align 4
 // CHECK: %[[VAL_107:.*]] = cmpxchg i32* %[[VAL_97]], i32 %[[VAL_105]], i32 %[[VAL_106]] seq_cst seq_cst
 // CHECK: %[[VAL_108:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 0
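The cmpxchg loop these CHECK lines verify is the standard compare-and-swap retry pattern used when unique_indices is false. A CPU analogue using std::atomic, illustrative rather than the emitted IR:

#include <atomic>
#include <cstdint>

// The emitted kernel implements "atomic multiply" by retrying a
// compare-exchange until no other thread has raced the update.
void AtomicMul(std::atomic<int32_t>* output, int32_t update) {
  int32_t old_val = output->load();
  int32_t new_val;
  do {
    new_val = old_val * update;  // the scatter op's update computation
    // On failure, compare_exchange_weak reloads old_val and we retry.
  } while (!output->compare_exchange_weak(old_val, new_val));
}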
@@ -219,14 +213,12 @@ ENTRY main {
 // CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) {
 // CHECK: entry:
 // CHECK: %[[VAL_146:.*]] = alloca i32, align 4
-// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
-// CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]*
-// CHECK: %[[VAL_121:.*]] = getelementptr inbounds i8, i8* %[[VAL_122:.*]], i64 0
-// CHECK: %[[VAL_123:.*]] = bitcast i8* %[[VAL_121]] to [4 x i32]*
 // CHECK: %[[VAL_124:.*]] = getelementptr inbounds i8, i8* %[[VAL_125:.*]], i64 0
 // CHECK: %[[VAL_126:.*]] = bitcast i8* %[[VAL_124]] to i32*
 // CHECK: %[[VAL_127:.*]] = getelementptr inbounds i8, i8* %[[VAL_128:.*]], i64 0
 // CHECK: %[[VAL_129:.*]] = bitcast i8* %[[VAL_127]] to i32*
+// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
+// CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]*
 // CHECK: %[[VAL_130:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
 // CHECK: %[[VAL_131:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
 // CHECK: %[[VAL_132:.*]] = mul nuw nsw i32 %[[VAL_130]], 1