[XLA/GPU] Factor out the logic of MLIR op -> kernel param slices to share with all emitters.

PiperOrigin-RevId: 335962025
Change-Id: Ie7d462bcf533d62bfdc5ddfb91666f20e26dd03b
Tim Shen 2020-10-07 15:09:35 -07:00 committed by TensorFlower Gardener
parent 84e7820964
commit 0406aa2fcc
2 changed files with 49 additions and 41 deletions


@@ -202,6 +202,31 @@ StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
"StaticMemRefCastOp(ViewOp(arg))");
}
StatusOr<std::vector<MlirBufferSlice>> GetMlirBufferSlices(
mlir::Operation* op, mlir::OperandRange operands,
absl::Span<const BufferAllocation> allocations) {
const auto buffer_is_written = [op](mlir::Value operand) {
llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
effects);
return absl::c_any_of(
effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
});
};
std::vector<MlirBufferSlice> slices;
for (mlir::Value operand : operands) {
slices.emplace_back();
auto& slice = slices.back();
TF_ASSIGN_OR_RETURN(slice.buffer_slice,
GetAllocationSliceForMlir(operand, allocations));
slice.written = buffer_is_written(operand);
slice.shape = TypeToShape(operand.getType());
}
return slices;
}
} // namespace
IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
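
For illustration only (not part of this commit): a minimal sketch of how another emitter could reuse the factored-out helper. EmitSomeOpFromMlir and mlir::lmhlo::SomeOp are hypothetical names; the sketch assumes the IrEmitterUnnested members that appear elsewhere in this diff (the buffer assignment, BuildKernelThunkForMlir, AddThunkToThunkSequence).

Status IrEmitterUnnested::EmitSomeOpFromMlir(MlirEmitterInput mlir_input) {
  absl::Span<const BufferAllocation> allocations(
      ir_emitter_context_->buffer_assignment().Allocations());
  // Hypothetical LMHLO op; GetMlirBufferSlices queries its
  // MemoryEffectOpInterface to decide which operands are written.
  auto some_op = mlir::cast<mlir::lmhlo::SomeOp>(mlir_input.op);
  // One MlirBufferSlice (buffer_slice, written, shape) per operand.
  TF_ASSIGN_OR_RETURN(
      std::vector<MlirBufferSlice> slices,
      GetMlirBufferSlices(some_op, some_op.operands(), allocations));
  std::vector<IrArray> ir_arrays;
  AddThunkToThunkSequence(
      BuildKernelThunkForMlir(mlir::GetNameFromLoc(some_op.getLoc()),
                              mlir_input.thunk_info, slices, &ir_arrays));
  // ... emit the kernel body against ir_arrays ...
  return Status::OK();
}
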
@@ -1371,47 +1396,30 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
return EmitSortFromMlir(result);
}
Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
absl::Span<const BufferAllocation> allocations(
ir_emitter_context_->buffer_assignment().Allocations());
auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(input.op);
auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(mlir_input.op);
std::string name = mlir::GetNameFromLoc(sort_op.getLoc());
int operand_count = sort_op.operands().size();
std::vector<xla::Shape> operand_shapes(operand_count);
std::vector<MlirBufferSlice> slices;
std::vector<xla::Shape> output_shapes(sort_op.output().size());
for (int i = 0; i < operand_count; i++) {
operand_shapes[i] = TypeToShape(sort_op.operands()[i].getType());
}
// Craft n + 1 slices, where the first n are output parameters, and the last
// is the on-device tuple storage. We don't need n operands because sorting
// kernels are always in-place.
for (int i = 0; i < operand_count; i++) {
output_shapes[i] = TypeToShape(sort_op.output()[i].getType());
MlirBufferSlice slice;
TF_ASSIGN_OR_RETURN(
slice.buffer_slice,
GetAllocationSliceForMlir(sort_op.output()[i], allocations));
slice.written = true;
slice.shape = operand_shapes[i];
slices.push_back(slice);
}
slices.push_back(input.extra_slice);
TF_ASSIGN_OR_RETURN(
std::vector<MlirBufferSlice> operands,
GetMlirBufferSlices(sort_op, sort_op.operands(), allocations));
TF_ASSIGN_OR_RETURN(
std::vector<MlirBufferSlice> outputs,
GetMlirBufferSlices(sort_op, sort_op.output(), allocations));
outputs.push_back(mlir_input.extra_slice);
std::vector<std::unique_ptr<Thunk>> thunks;
Shape keys_shape = operand_shapes[0];
Shape keys_shape = operands[0].shape;
int64 dimension_to_sort = sort_op.dimension();
for (int64 i = 0; i < operand_count; ++i) {
for (int64 i = 0; i < operands.size(); ++i) {
// We assume that the layout of all involved operands and outputs is the
// same.
TF_RET_CHECK(
LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i]));
LayoutUtil::LayoutsInShapesEqual(keys_shape, operands[i].shape));
TF_RET_CHECK(
LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i]));
LayoutUtil::LayoutsInShapesEqual(keys_shape, outputs[i].shape));
// If possible, we share buffers. If that is not possible, we need to copy
// the values, because the emitter does the sorting in-place.
@@ -1429,7 +1437,7 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
Thunk::ThunkInfo(),
/*source_address=*/source_address,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i])));
/*mem_size=*/ShapeUtil::ByteSizeOf(operands[i].shape)));
}
}
@@ -1499,10 +1507,10 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
// we have not enough threads, or not enough shared memory. Also it does not
// give a speedup if the tile size is < 128.
int64 total_shared_memory_needed = 0;
for (int64 i = 0; i < operand_count; ++i) {
for (int64 i = 0; i < operands.size(); ++i) {
total_shared_memory_needed +=
kTileSize *
ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type());
ShapeUtil::ByteSizeOfPrimitiveType(operands[i].shape.element_type());
}
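// Illustrative arithmetic only, not part of this commit: with two f32 sort
// operands and a hypothetical kTileSize of 2048, the loop above would give
// total_shared_memory_needed = 2 * 2048 * 4 = 16384 bytes, which the check
// below weighs against the available shared memory before enabling tiling.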
bool no_tiling =
kTileSize < 128 ||
@@ -1533,15 +1541,15 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
absl::StrAppendFormat(out, "0x%x", xor_mask);
}));
thunks.push_back(
BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), slices, &ir_arrays));
BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), outputs, &ir_arrays));
LaunchDimensions launch_dimensions = xor_masks.size() > 1
? tiled_launch_dimensions
: standard_launch_dimensions;
UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
ir_emitter_context_->llvm_module());
std::vector<IrArray> values_arrays;
values_arrays.reserve(operand_count);
for (int64 i = 0; i < operand_count; ++i) {
values_arrays.reserve(operands.size());
for (int64 i = 0; i < operands.size(); ++i) {
values_arrays.push_back(ir_arrays[i]);
}
TF_ASSIGN_OR_RETURN(
@@ -1583,14 +1591,14 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
VLOG(2) << absl::StreamFormat(
"%s requires %d thunks (including any D2D copies)", name, thunks.size());
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(input.thunk_info, std::move(thunks)));
if (operand_count > 1) {
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
mlir_input.thunk_info, std::move(thunks)));
if (operands.size() > 1) {
// Emit the tuple as part of the last stage of sorting.
// We are currently in the block sorted.in_bounds.after.
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
llvm_ir::EmitTuple(
ir_arrays[operand_count],
ir_arrays.back(),
absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_);
}
return Status::OK();


@@ -160,7 +160,7 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleScatter(HloInstruction* scatter) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleSort(HloInstruction* sort) override;
Status EmitSortFromMlir(MlirEmitterInput input);
Status EmitSortFromMlir(MlirEmitterInput mlir_input);
Status HandleTriangularSolve(HloInstruction* hlo) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleAllReduce(HloInstruction* crs) override;