[XLA/GPU] Factor out the logic of MLIR op -> kernel param slices to share with all emitters.
PiperOrigin-RevId: 335962025
Change-Id: Ie7d462bcf533d62bfdc5ddfb91666f20e26dd03b
parent 84e7820964
commit 0406aa2fcc
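The change below factors the mapping from an MLIR (LMHLO) op's values to kernel parameter slices into a shared helper, GetMlirBufferSlices: for each value it records the backing BufferAllocation::Slice, whether the op writes it, and its XLA shape, so every MLIR-based emitter can build kernel thunks the same way. As a rough illustration of the intended reuse, a hypothetical caller might look like the sketch below; everything except GetMlirBufferSlices and BuildKernelThunkForMlir (both visible in this diff) is assumed context, not part of this change.

    // Sketch only: a hypothetical emitter reusing the shared helper.
    // `op`, `allocations`, `name`, `thunks`, and `ir_arrays` are assumed to be
    // in scope, as they are in EmitSortFromMlir below.
    TF_ASSIGN_OR_RETURN(
        std::vector<MlirBufferSlice> arg_slices,
        GetMlirBufferSlices(op, op->getOperands(), allocations));
    thunks.push_back(BuildKernelThunkForMlir(name, Thunk::ThunkInfo(),
                                             arg_slices, &ir_arrays));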
@@ -202,6 +202,31 @@ StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
                           "StaticMemRefCastOp(ViewOp(arg))");
 }
 
+StatusOr<std::vector<MlirBufferSlice>> GetMlirBufferSlices(
+    mlir::Operation* op, mlir::OperandRange operands,
+    absl::Span<const BufferAllocation> allocations) {
+  const auto buffer_is_written = [op](mlir::Value operand) {
+    llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
+    mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
+                                                                    effects);
+    return absl::c_any_of(
+        effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
+          return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
+        });
+  };
+
+  std::vector<MlirBufferSlice> slices;
+  for (mlir::Value operand : operands) {
+    slices.emplace_back();
+    auto& slice = slices.back();
+    TF_ASSIGN_OR_RETURN(slice.buffer_slice,
+                        GetAllocationSliceForMlir(operand, allocations));
+    slice.written = buffer_is_written(operand);
+    slice.shape = TypeToShape(operand.getType());
+  }
+  return slices;
+}
+
 }  // namespace
 
 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
@@ -1371,47 +1396,30 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
   return EmitSortFromMlir(result);
 }
 
-Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
+Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
   absl::Span<const BufferAllocation> allocations(
       ir_emitter_context_->buffer_assignment().Allocations());
-  auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(input.op);
+  auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(mlir_input.op);
   std::string name = mlir::GetNameFromLoc(sort_op.getLoc());
-
-  int operand_count = sort_op.operands().size();
-  std::vector<xla::Shape> operand_shapes(operand_count);
-  std::vector<MlirBufferSlice> slices;
-  std::vector<xla::Shape> output_shapes(sort_op.output().size());
-
-  for (int i = 0; i < operand_count; i++) {
-    operand_shapes[i] = TypeToShape(sort_op.operands()[i].getType());
-  }
-
-  // Craft n + 1 slices, where the first n are output parameters, and the last
-  // is the on-device tuple storage. We don't need n operands because sorting
-  // kernels are always in-place.
-  for (int i = 0; i < operand_count; i++) {
-    output_shapes[i] = TypeToShape(sort_op.output()[i].getType());
-    MlirBufferSlice slice;
-    TF_ASSIGN_OR_RETURN(
-        slice.buffer_slice,
-        GetAllocationSliceForMlir(sort_op.output()[i], allocations));
-    slice.written = true;
-    slice.shape = operand_shapes[i];
-    slices.push_back(slice);
-  }
-  slices.push_back(input.extra_slice);
+  TF_ASSIGN_OR_RETURN(
+      std::vector<MlirBufferSlice> operands,
+      GetMlirBufferSlices(sort_op, sort_op.operands(), allocations));
+  TF_ASSIGN_OR_RETURN(
+      std::vector<MlirBufferSlice> outputs,
+      GetMlirBufferSlices(sort_op, sort_op.output(), allocations));
+  outputs.push_back(mlir_input.extra_slice);
 
   std::vector<std::unique_ptr<Thunk>> thunks;
 
-  Shape keys_shape = operand_shapes[0];
+  Shape keys_shape = operands[0].shape;
   int64 dimension_to_sort = sort_op.dimension();
-  for (int64 i = 0; i < operand_count; ++i) {
+  for (int64 i = 0; i < operands.size(); ++i) {
     // We assume that the layout of all involved operands and outputs is the
     // same.
     TF_RET_CHECK(
-        LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i]));
+        LayoutUtil::LayoutsInShapesEqual(keys_shape, operands[i].shape));
     TF_RET_CHECK(
-        LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i]));
+        LayoutUtil::LayoutsInShapesEqual(keys_shape, outputs[i].shape));
 
     // If possible, we share buffers. If that is not possible, we need to copy
     // the values, because the emitter does the sorting in-place.
@@ -1429,7 +1437,7 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
           Thunk::ThunkInfo(),
           /*source_address=*/source_address,
           /*destination_buffer=*/destination_buffer,
-          /*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i])));
+          /*mem_size=*/ShapeUtil::ByteSizeOf(operands[i].shape)));
     }
   }
 
@@ -1499,10 +1507,10 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
   // we have not enough threads, or not enough shared memory. Also it does not
   // give a speedup if the tile size is < 128.
   int64 total_shared_memory_needed = 0;
-  for (int64 i = 0; i < operand_count; ++i) {
+  for (int64 i = 0; i < operands.size(); ++i) {
    total_shared_memory_needed +=
        kTileSize *
-        ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type());
+        ShapeUtil::ByteSizeOfPrimitiveType(operands[i].shape.element_type());
  }
  bool no_tiling =
      kTileSize < 128 ||
@@ -1533,15 +1541,15 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
           absl::StrAppendFormat(out, "0x%x", xor_mask);
         }));
     thunks.push_back(
-        BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), slices, &ir_arrays));
+        BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), outputs, &ir_arrays));
     LaunchDimensions launch_dimensions = xor_masks.size() > 1
                                              ? tiled_launch_dimensions
                                              : standard_launch_dimensions;
     UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
                            ir_emitter_context_->llvm_module());
     std::vector<IrArray> values_arrays;
-    values_arrays.reserve(operand_count);
-    for (int64 i = 0; i < operand_count; ++i) {
+    values_arrays.reserve(operands.size());
+    for (int64 i = 0; i < operands.size(); ++i) {
       values_arrays.push_back(ir_arrays[i]);
     }
     TF_ASSIGN_OR_RETURN(
@@ -1583,14 +1591,14 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
   VLOG(2) << absl::StreamFormat(
       "%s requires %d thunks (including any D2D copies)", name, thunks.size());
 
-  AddThunkToThunkSequence(
-      absl::make_unique<SequentialThunk>(input.thunk_info, std::move(thunks)));
-  if (operand_count > 1) {
+  AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
+      mlir_input.thunk_info, std::move(thunks)));
+  if (operands.size() > 1) {
     // Emit the tuple as part of the last stage of sorting.
     // We are currently in the block sorted.in_bounds.after.
     b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
     llvm_ir::EmitTuple(
-        ir_arrays[operand_count],
+        ir_arrays.back(),
         absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_);
   }
   return Status::OK();

@@ -160,7 +160,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleScatter(HloInstruction* scatter) override;
   Status HandleSelect(HloInstruction* select) override;
   Status HandleSort(HloInstruction* sort) override;
-  Status EmitSortFromMlir(MlirEmitterInput input);
+  Status EmitSortFromMlir(MlirEmitterInput mlir_input);
   Status HandleTriangularSolve(HloInstruction* hlo) override;
   Status HandleTupleSelect(HloInstruction* tuple_select) override;
   Status HandleAllReduce(HloInstruction* crs) override;