Roll-forward with fix:

[XLA/GPU] Migrate fused Slice to take LMHLO.

PiperOrigin-RevId: 355491691
Change-Id: I45147b974f4b5e1cbf886945c1078a98e088beba
Authored by Tim Shen on 2021-02-03 15:08:27 -08:00; committed by TensorFlower Gardener
parent 2ddb6e97f6
commit 2fb379a67c
5 changed files with 93 additions and 76 deletions


@@ -369,29 +369,29 @@ bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce) {
return reduction_dimensions.dimensions[1] >= kWarpSize;
}
bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
bool verify_no_strides) {
if (!unnested_hlo.IsInputFusion()) {
auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
if (!fusion) {
return false;
}
auto is_non_strided = [](const std::vector<int64>& strides) -> bool {
return absl::c_all_of(strides, [](int stride) { return stride == 1; });
auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
return absl::c_all_of(
strides, [](const llvm::APInt& stride) { return stride == 1; });
};
const HloInstruction* root = unnested_hlo.fused_expression_root();
if (root->opcode() == HloOpcode::kSlice) {
return !verify_no_strides || is_non_strided(root->slice_strides());
for (mlir::Value value : fusion.getFusionResults()) {
auto slice =
mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
if (!slice) {
return false;
}
if (verify_no_strides && !is_non_strided(slice.strides())) {
return false;
}
}
if (root->opcode() != HloOpcode::kTuple) {
return false;
}
return absl::c_all_of(root->operands(), [&](const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kSlice &&
(!verify_no_strides || is_non_strided(instr->slice_strides()));
});
return true;
}
ReductionDimensions GetReductionKindAndContiguousComponents(
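For readers coming from the HLO side, the new predicate leans on two MLIR idioms: a fusion result's defining op may be null (hence dyn_cast_or_null), and mlir::DenseIntElementsAttr iterates as llvm::APInt values, which is what makes the all-strides-equal-one check a one-liner. Below is a minimal, standalone sketch of just that strides predicate; the header paths and the getI64TensorAttr helper are assumptions about the MLIR version in use, not part of this change.

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

// Same shape as is_non_strided above: every stride in the attribute must be 1.
static bool IsNonStrided(mlir::DenseIntElementsAttr strides) {
  return llvm::all_of(
      strides, [](const llvm::APInt& stride) { return stride == 1; });
}

int main() {
  mlir::MLIRContext context;
  mlir::Builder b(&context);
  // getI64TensorAttr builds a 1-D DenseIntElementsAttr from int64_t values.
  bool ok = IsNonStrided(b.getI64TensorAttr({1, 1, 1})) &&   // unit strides
            !IsNonStrided(b.getI64TensorAttr({1, 2, 1}));    // strided
  return ok ? 0 : 1;
}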


@@ -168,8 +168,8 @@ bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce);
// Returns whether unnested_hlo is an input fusion whose root is either a slice
// or a tuple of slices. If verify_no_strides is true, returns false unless all
// ROOT slices have no strides.
bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
bool verify_no_strides = false);
bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
bool verify_no_strides);
struct ReductionDimensions {
// Indicates whether the reduction is a row reduction or a column reduction.


@@ -490,34 +490,36 @@ llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op,
// slices are the same and the slices are non-strided. Otherwise, returns
// FailedPrecondition.
StatusOr<Shape> GetConsistentInputShapeForRootSlices(
const HloInstruction& fusion) {
mlir::lmhlo::FusionOp fusion) {
if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
return FailedPrecondition(
"Unsupported root for slice input fusion. "
"Only non-strided slices are supported.");
}
const HloInstruction& root = *fusion.fused_expression_root();
if (root.opcode() == HloOpcode::kSlice) {
return root.operands()[0]->shape();
}
CHECK_EQ(root.opcode(), HloOpcode::kTuple);
const Shape& first_slice_operand_shape =
root.operands()[0]->operands()[0]->shape();
for (size_t i = 1; i < root.operands().size(); ++i) {
const HloInstruction* slice = root.operands()[i];
const Shape& operand_shape = slice->operands()[0]->shape();
if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
operand_shape)) {
return FailedPrecondition(
"Fused slices do not have the same input shape, fused computation = "
"%s.",
root.parent()->name());
absl::optional<Shape> first_slice_operand_shape;
for (mlir::Value result : fusion.getFusionResults()) {
auto slice =
mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(result.getDefiningOp());
if (!slice) {
return FailedPrecondition("Expected a slice op");
}
if (first_slice_operand_shape.has_value()) {
Shape operand_shape = TypeToShape(slice.operand().getType());
if (!ShapeUtil::EqualIgnoringElementType(*first_slice_operand_shape,
operand_shape)) {
return FailedPrecondition(
"Fused slices do not have the same input shape, instruction is %s",
MlirToString(fusion));
}
} else {
first_slice_operand_shape = TypeToShape(slice.operand().getType());
}
}
return first_slice_operand_shape;
if (!first_slice_operand_shape.has_value()) {
return InvalidArgument("Fusion has no roots");
}
return *first_slice_operand_shape;
}
} // namespace
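The rewritten helper above no longer branches on a kSlice vs. kTuple root; it walks every fusion result, remembers the first slice-operand shape it sees, and compares each later one against it, failing if any differ or if there are no results at all. A small self-contained sketch of that control flow, with plain dimension vectors standing in for xla::Shape and TypeToShape (the names below are illustrative, not the upstream API):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

using Dims = std::vector<int64_t>;  // stand-in for a shape's dimensions

// Returns the common slice-operand dimensions, or nullopt if the slices
// disagree or there are no fusion results at all (the two error paths above).
std::optional<Dims> ConsistentInputDims(const std::vector<Dims>& operands) {
  std::optional<Dims> first;
  for (const Dims& dims : operands) {
    if (first.has_value()) {
      if (*first != dims) return std::nullopt;  // shapes are inconsistent
    } else {
      first = dims;  // remember the first slice operand
    }
  }
  return first;  // nullopt when the fusion had no results
}

int main() {
  assert(ConsistentInputDims({{8, 128}, {8, 128}}).has_value());
  assert(!ConsistentInputDims({{8, 128}, {4, 128}}).has_value());
  assert(!ConsistentInputDims({}).has_value());
  return 0;
}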
@@ -1842,8 +1844,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// In the case of root tuple, it can be either reduce or slice input
// fusion.
case HloOpcode::kTuple: {
if (IsInputFusibleSlices(*fusion)) {
return EmitInputFusibleNonStridedSlices(fusion);
if (IsInputFusibleSlices(mlir_input.op, /*verify_no_strides=*/false)) {
return EmitInputFusibleNonStridedSlices(mlir_input);
}
CHECK_GE(mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op)
@@ -1864,7 +1866,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
return EmitReductionFromOrToContiguousDimensions(mlir_input);
}
case HloOpcode::kSlice: {
return EmitInputFusibleNonStridedSlices(fusion);
return EmitInputFusibleNonStridedSlices(mlir_input);
}
default:
LOG(FATAL) << "Bad opcode for input fusion: "
@@ -5636,11 +5638,16 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
// Write to output of slice1
// }
//
void IrEmitterUnnested::EmitElementForInputFusibleSlices(
HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index) {
VLOG(10) << "Emitting slice input fusion for " << unnested_hlo->ToString();
Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
mlir::lmhlo::FusionOp fusion, absl::Span<const llvm_ir::IrArray> ir_arrays,
const llvm_ir::IrArray::Index& index) {
VLOG(10) << "Emitting slice input fusion for " << MlirToString(fusion);
HloInstruction* slice_or_tuple = unnested_hlo->fused_expression_root();
TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
GetOrCreateSubComputationFromRegion(&fusion.region(),
/*is_fusion=*/true));
HloInstruction* slice_or_tuple = fused_computation->root_instruction();
auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
@@ -5654,7 +5661,13 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices(
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(&elem_emitter);
BindFusionArguments(unnested_hlo, &fused_emitter);
for (int i = 0; i < fused_computation->num_parameters(); i++) {
fused_emitter.BindGenerator(
fused_computation->parameter_instruction(i),
[this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
return ir_arrays[i].EmitReadArrayElement(index, &b_);
});
}
for (const HloInstruction* slice : slice_instructions) {
auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
input_ir_values.push_back(input_generator(index).ValueOrDie());
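The parameter-binding loop above is what replaces the old BindFusionArguments call: each fused-computation parameter i gets a callback that reads element index from ir_arrays[i], so the capture of i by value is what ties generator i to array i. A minimal sketch of that pattern outside XLA, with std::function standing in for the IR generators:

#include <cstdio>
#include <functional>
#include <vector>

int main() {
  std::vector<std::vector<float>> arrays = {{1.f, 2.f}, {10.f, 20.f}};
  std::vector<std::function<float(int)>> generators;
  for (int i = 0; i < static_cast<int>(arrays.size()); ++i) {
    // i is captured by value: generator i always reads from arrays[i],
    // mirroring fused_emitter.BindGenerator(parameter_instruction(i), ...).
    generators.push_back(
        [&arrays, i](int index) { return arrays[i][index]; });
  }
  std::printf("%g %g\n", generators[0](1), generators[1](0));  // prints: 2 10
  return 0;
}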
@@ -5689,11 +5702,8 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices(
Sub(src_multidim[dim],
index.GetConstantWithIndexType(slice->slice_starts(dim)));
}
ShapeIndex shape_index = (slice_or_tuple->opcode() == HloOpcode::kSlice)
? ShapeIndex()
: ShapeIndex({i});
llvm_ir::IrArray src_ir_array =
GetIrArray(*unnested_hlo, *unnested_hlo, shape_index);
ir_arrays[fused_computation->num_parameters() + i];
IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
index.GetType());
src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
@@ -5702,16 +5712,23 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices(
ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
}
return Status::OK();
}
Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
HloInstruction* unnested_hlo) {
MlirEmitterInput mlir_input) {
auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
constexpr int unroll_factor = 1;
std::unique_ptr<KernelThunk> kernel_thunk =
BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/true);
std::vector<llvm_ir::IrArray> ir_arrays;
TF_ASSIGN_OR_RETURN(
auto kernel_thunk,
BuildKernelThunkForMlir(fusion, mlir_input.thunk_info,
mlir_input.extra_slice, &ir_arrays));
TF_ASSIGN_OR_RETURN(Shape element_shape,
GetConsistentInputShapeForRootSlices(*unnested_hlo));
GetConsistentInputShapeForRootSlices(fusion));
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
@@ -5720,13 +5737,12 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
Status emit_status =
ParallelLoopEmitter(
[&](const llvm_ir::IrArray::Index index) -> Status {
EmitElementForInputFusibleSlices(unnested_hlo, index);
return Status::OK();
return EmitElementForInputFusibleSlices(fusion, ir_arrays, index);
},
element_shape, launch_dimensions, &b_)
.EmitLoop(IrName(unnested_hlo),
GetIndexTypeForKernel(
unnested_hlo, launch_dimensions.launch_bound(), &b_));
.EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
GetIndexTypeForKernelFromMlir(
fusion, launch_dimensions.launch_bound(), &b_));
thunk_sequence_.emplace_back(std::move(kernel_thunk));
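For orientation, CalculateLaunchDimensions above boils down to a ceiling division of the element count by threads-per-block times the unroll factor; the real routine also consults gpu_device_info for the per-block thread limit. A simplified sketch of that arithmetic (an assumption about its behavior, not the XLA implementation):

#include <cstdint>
#include <cstdio>

struct LaunchDims {
  int64_t block_count;
  int64_t threads_per_block;
};

// One thread covers unroll_factor elements; round the block count up.
LaunchDims CalcLaunchDims(int64_t num_elements, int64_t threads_per_block,
                          int64_t unroll_factor) {
  const int64_t elems_per_block = threads_per_block * unroll_factor;
  const int64_t blocks =
      (num_elements + elems_per_block - 1) / elems_per_block;
  return {blocks, threads_per_block};
}

int main() {
  // 2047 output elements, 1024 threads per block, no unrolling -> 2 blocks.
  LaunchDims dims = CalcLaunchDims(2047, 1024, 1);
  std::printf("%lld blocks x %lld threads\n",
              static_cast<long long>(dims.block_count),
              static_cast<long long>(dims.threads_per_block));
  return 0;
}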


@@ -446,11 +446,12 @@ class IrEmitterUnnested : public IrEmitter,
// different. On the other hand, the input ranges of slices can be
// overlapping. Further generalization/specialization when the needs are seen
// in the future.
Status EmitInputFusibleNonStridedSlices(HloInstruction* unnested_hlo);
Status EmitInputFusibleNonStridedSlices(MlirEmitterInput mlir_input);
void EmitElementForInputFusibleSlices(
HloInstruction* unnested_hlo,
const llvm_ir::IrArray::Index& slice_input_index);
Status EmitElementForInputFusibleSlices(
mlir::lmhlo::FusionOp fusion,
absl::Span<const llvm_ir::IrArray> ir_arrays,
const llvm_ir::IrArray::Index& index);
// Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
// the process. Scatter indices are taken from `scatter_indices_gen`, updates


@@ -2,21 +2,21 @@
// CHECK-LABEL: entry:
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x i8*]*
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [1024 x half]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [0 x half]*
// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [1023 x half]*
// CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
// CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
// CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]*
// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1023 x half]*
// CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1023 x half]*
// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [0 x half]*
// CHECK: %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0
// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [1023 x half]*
// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [3 x i8*]*
// CHECK: %[[VAL_24:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_25:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_26:.*]] = mul nuw nsw i32 %[[VAL_24]], 1024
@@ -34,18 +34,18 @@
// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_31]]
// CHECK: %[[VAL_38:.*]] = phi i32 [ 0, %[[VAL_31]] ]
// CHECK: %[[VAL_39:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_38]]
// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_39]]
// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_2]], i32 0, i32 %[[VAL_39]]
// CHECK: %[[VAL_41:.*]] = load half, half* %[[VAL_40]], align 2, !invariant.load !4
// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_39]]
// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_39]]
// CHECK: %[[VAL_43:.*]] = load half, half* %[[VAL_42]], align 2, !invariant.load !4
// CHECK: %[[VAL_44:.*]] = fmul half %[[VAL_41]], %[[VAL_43]]
// CHECK: br label %[[VAL_45:.*]]
// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_37]]
// CHECK: %[[VAL_46:.*]] = phi i32 [ 1024, %[[VAL_37]] ]
// CHECK: %[[VAL_47:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_46]]
// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_47]]
// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_47]]
// CHECK: %[[VAL_49:.*]] = load half, half* %[[VAL_48]], align 2, !invariant.load !4
// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_47]]
// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_47]]
// CHECK: %[[VAL_51:.*]] = load half, half* %[[VAL_50]], align 2, !invariant.load !4
// CHECK: %[[VAL_52:.*]] = fadd half %[[VAL_49]], %[[VAL_51]]
// CHECK: br label %[[VAL_45]]
@@ -54,7 +54,7 @@
// CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]]
// CHECK: concat_index_not_from_operand1: ; preds = %[[VAL_37]]
// CHECK: unreachable
// CHECK: concat.1.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
// CHECK: concatenate.7.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
// CHECK: %[[VAL_56:.*]] = phi half [ %[[VAL_44]], %[[VAL_36]] ], [ %[[VAL_52]], %[[VAL_54]] ]
// CHECK: %[[VAL_57:.*]] = icmp sge i32 %[[VAL_29]], 0
// CHECK: %[[VAL_58:.*]] = icmp slt i32 %[[VAL_29]], 1024
@@ -74,17 +74,17 @@
// CHECK: br label %[[VAL_32]]
// CHECK: slice0-true: ; preds = %[[VAL_45]]
// CHECK: %[[VAL_71:.*]] = sub i32 %[[VAL_29]], 0
// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_71]]
// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_71]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_72]], align 2
// CHECK: br label %[[VAL_61]]
// CHECK: slice1-true: ; preds = %[[VAL_61]]
// CHECK: %[[VAL_73:.*]] = sub i32 %[[VAL_29]], 1024
// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_73]]
// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_73]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_74]], align 2
// CHECK: br label %[[VAL_66]]
// CHECK: slice2-true: ; preds = %[[VAL_66]]
// CHECK: %[[VAL_75:.*]] = sub i32 %[[VAL_29]], 2047
// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_75]]
// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_75]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_76]], align 2
// CHECK: br label %[[VAL_33]]