Roll-forward with fix:

[XLA/GPU] Migrate fused Slice to take LMHLO.

PiperOrigin-RevId: 355491691
Change-Id: I45147b974f4b5e1cbf886945c1078a98e088beba

parent 2ddb6e97f6
commit 2fb379a67c
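In short: the slice input-fusion path in the GPU IR emitter no longer walks the `HloInstruction` fusion; it takes the LMHLO (MLIR) form of the op. A minimal before/after sketch of the dispatch site changed further down in `HandleFusion` (`mlir_input` is the `MlirEmitterInput` used throughout this diff):

// Before: HLO-based predicate and emitter.
if (IsInputFusibleSlices(*fusion)) {                    // const HloInstruction&
  return EmitInputFusibleNonStridedSlices(fusion);      // HloInstruction*
}

// After: the same entry points take the LMHLO form.
if (IsInputFusibleSlices(mlir_input.op, /*verify_no_strides=*/false)) {
  return EmitInputFusibleNonStridedSlices(mlir_input);  // MlirEmitterInput
}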
@@ -369,29 +369,29 @@ bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce) {
   return reduction_dimensions.dimensions[1] >= kWarpSize;
 }
 
-bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
+bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                           bool verify_no_strides) {
-  if (!unnested_hlo.IsInputFusion()) {
+  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
+  if (!fusion) {
     return false;
   }
 
-  auto is_non_strided = [](const std::vector<int64>& strides) -> bool {
-    return absl::c_all_of(strides, [](int stride) { return stride == 1; });
+  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
+    return absl::c_all_of(
+        strides, [](const llvm::APInt& stride) { return stride == 1; });
   };
 
-  const HloInstruction* root = unnested_hlo.fused_expression_root();
-  if (root->opcode() == HloOpcode::kSlice) {
-    return !verify_no_strides || is_non_strided(root->slice_strides());
+  for (mlir::Value value : fusion.getFusionResults()) {
+    auto slice =
+        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
+    if (!slice) {
+      return false;
+    }
+    if (verify_no_strides && !is_non_strided(slice.strides())) {
+      return false;
+    }
   }
-
-  if (root->opcode() != HloOpcode::kTuple) {
-    return false;
-  }
-
-  return absl::c_all_of(root->operands(), [&](const HloInstruction* instr) {
-    return instr->opcode() == HloOpcode::kSlice &&
-           (!verify_no_strides || is_non_strided(instr->slice_strides()));
-  });
+  return true;
 }
 
 ReductionDimensions GetReductionKindAndContiguousComponents(
@@ -168,8 +168,8 @@ bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce);
 // Returns whether unnested_hlo is an input fusion whose root is either a slice
 // or a tuple of slices. If verify_no_strides is true, returns false unless all
 // ROOT slices have no strides.
-bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
-                          bool verify_no_strides = false);
+bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
+                          bool verify_no_strides);
 
 struct ReductionDimensions {
   // Indicates whether the reduction is a row reduction or a column reduction.
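Note that the new declaration drops the `= false` default, so every caller now spells out `verify_no_strides` (see the `HandleFusion` hunk below). A minimal usage sketch under the new signature; `op` here is an illustrative stand-in for any `mlir::Operation*` such as `MlirEmitterInput::op`:

// The predicate now answers: is `op` an lmhlo::FusionOp whose every result
// is produced by an mhlo::SliceOp (with all-ones strides, if requested)?
if (IsInputFusibleSlices(op, /*verify_no_strides=*/true)) {
  // Safe to emit the non-strided slice input-fusion kernel.
}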
@@ -490,34 +490,36 @@ llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op,
 // slices are the same and the slices are non-strided. Otherwise, returns
 // FailedPrecondition.
 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
-    const HloInstruction& fusion) {
+    mlir::lmhlo::FusionOp fusion) {
   if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
     return FailedPrecondition(
         "Unsupported root for slice input fusion. "
         "Only non-strided slices are supported.");
   }
 
-  const HloInstruction& root = *fusion.fused_expression_root();
-  if (root.opcode() == HloOpcode::kSlice) {
-    return root.operands()[0]->shape();
-  }
-
-  CHECK_EQ(root.opcode(), HloOpcode::kTuple);
-  const Shape& first_slice_operand_shape =
-      root.operands()[0]->operands()[0]->shape();
-  for (size_t i = 1; i < root.operands().size(); ++i) {
-    const HloInstruction* slice = root.operands()[i];
-    const Shape& operand_shape = slice->operands()[0]->shape();
-    if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
-                                             operand_shape)) {
-      return FailedPrecondition(
-          "Fused slices do not have the same input shape, fused computation = "
-          "%s.",
-          root.parent()->name());
+  absl::optional<Shape> first_slice_operand_shape;
+  for (mlir::Value result : fusion.getFusionResults()) {
+    auto slice =
+        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(result.getDefiningOp());
+    if (!slice) {
+      return FailedPrecondition("Expected a slice op");
+    }
+    if (first_slice_operand_shape.has_value()) {
+      Shape operand_shape = TypeToShape(slice.operand().getType());
+      if (!ShapeUtil::EqualIgnoringElementType(*first_slice_operand_shape,
+                                               operand_shape)) {
+        return FailedPrecondition(
+            "Fused slices do not have the same input shape, instruction is %s",
+            MlirToString(fusion));
+      }
+    } else {
+      first_slice_operand_shape = TypeToShape(slice.operand().getType());
     }
   }
-
-  return first_slice_operand_shape;
+  if (!first_slice_operand_shape.has_value()) {
+    return InvalidArgument("Fusion has no roots");
+  }
+  return *first_slice_operand_shape;
 }
 
 }  // namespace
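A worked example of the invariant, using the shapes visible in the FileCheck test at the bottom of this diff (the f16[2047] input is the concatenate of the [1024 x half] and [1023 x half] operands; a sketch, not code from the commit):

// All ROOT slices must cut from inputs of the same shape:
//   slice(f16[2047]) -> f16[1024]   (offsets 0..1024)
//   slice(f16[2047]) -> f16[1023]   (offsets 1024..2047)
//   slice(f16[2047]) -> f16[0]      (offsets 2047..2047)
//   => GetConsistentInputShapeForRootSlices returns f16[2047]
// Slices of differently shaped inputs (say f16[2047] and f16[512]) would
// instead yield FailedPrecondition.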
@@ -1842,8 +1844,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
       // In the case of root tuple, it can be either reduce or slice input
       // fusion.
       case HloOpcode::kTuple: {
-        if (IsInputFusibleSlices(*fusion)) {
-          return EmitInputFusibleNonStridedSlices(fusion);
+        if (IsInputFusibleSlices(mlir_input.op, /*verify_no_strides=*/false)) {
+          return EmitInputFusibleNonStridedSlices(mlir_input);
         }
 
         CHECK_GE(mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op)
@@ -1864,7 +1866,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
         return EmitReductionFromOrToContiguousDimensions(mlir_input);
       }
       case HloOpcode::kSlice: {
-        return EmitInputFusibleNonStridedSlices(fusion);
+        return EmitInputFusibleNonStridedSlices(mlir_input);
       }
       default:
         LOG(FATAL) << "Bad opcode for input fusion: "
@@ -5636,11 +5638,16 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
 //   Write to output of slice1
 // }
 //
-void IrEmitterUnnested::EmitElementForInputFusibleSlices(
-    HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index) {
-  VLOG(10) << "Emitting slice input fusion for " << unnested_hlo->ToString();
+Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
+    mlir::lmhlo::FusionOp fusion, absl::Span<const llvm_ir::IrArray> ir_arrays,
+    const llvm_ir::IrArray::Index& index) {
+  VLOG(10) << "Emitting slice input fusion for " << MlirToString(fusion);
 
-  HloInstruction* slice_or_tuple = unnested_hlo->fused_expression_root();
+  TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
+                      GetOrCreateSubComputationFromRegion(&fusion.region(),
+                                                          /*is_fusion=*/true));
+
+  HloInstruction* slice_or_tuple = fused_computation->root_instruction();
   auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
     if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
       return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
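The `//   Write to output of slice1` lines above are the tail of this function's doc comment, which sketches the per-element code it emits: one guarded write per ROOT slice. Schematically (the `slice0-true` block names are real and appear in the CHECK test at the end of this diff; the C-like rendering is a sketch):

// for each element index i of the fused computation:
//   v = fused_value(i);                 // elementwise result at i
//   if (start0 <= i && i < limit0)      // guarding_cond, cf. ksl.If("slice0")
//     out0[i - start0] = v;             // "slice0-true" block
//   if (start1 <= i && i < limit1)
//     out1[i - start1] = v;             // "slice1-true" block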
@@ -5654,7 +5661,13 @@ Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
                                      GetNestedComputer());
   FusedIrEmitter fused_emitter(&elem_emitter);
-  BindFusionArguments(unnested_hlo, &fused_emitter);
+  for (int i = 0; i < fused_computation->num_parameters(); i++) {
+    fused_emitter.BindGenerator(
+        fused_computation->parameter_instruction(i),
+        [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
+          return ir_arrays[i].EmitReadArrayElement(index, &b_);
+        });
+  }
   for (const HloInstruction* slice : slice_instructions) {
     auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
     input_ir_values.push_back(input_generator(index).ValueOrDie());
@@ -5689,11 +5702,8 @@ Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
           Sub(src_multidim[dim],
               index.GetConstantWithIndexType(slice->slice_starts(dim)));
     }
-    ShapeIndex shape_index = (slice_or_tuple->opcode() == HloOpcode::kSlice)
-                                 ? ShapeIndex()
-                                 : ShapeIndex({i});
     llvm_ir::IrArray src_ir_array =
-        GetIrArray(*unnested_hlo, *unnested_hlo, shape_index);
+        ir_arrays[fused_computation->num_parameters() + i];
     IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
                                    index.GetType());
     src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
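Taken together, `ir_arrays[i]` for parameter i (previous hunk) and `ir_arrays[fused_computation->num_parameters() + i]` for slice output i imply the layout that `BuildKernelThunkForMlir` fills in; this is an inference from the indexing in this diff, not something the code states:

// Implied ir_arrays layout:
//   ir_arrays[0 .. num_parameters-1]   kernel views of the fusion inputs
//   ir_arrays[num_parameters + i]      kernel view of the i-th slice output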
@@ -5702,16 +5712,23 @@ Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
 
     ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
   }
+  return Status::OK();
 }
 
 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
-    HloInstruction* unnested_hlo) {
+    MlirEmitterInput mlir_input) {
+  auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
+
   constexpr int unroll_factor = 1;
-  std::unique_ptr<KernelThunk> kernel_thunk =
-      BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/true);
+
+  std::vector<llvm_ir::IrArray> ir_arrays;
+  TF_ASSIGN_OR_RETURN(
+      auto kernel_thunk,
+      BuildKernelThunkForMlir(fusion, mlir_input.thunk_info,
+                              mlir_input.extra_slice, &ir_arrays));
 
   TF_ASSIGN_OR_RETURN(Shape element_shape,
-                      GetConsistentInputShapeForRootSlices(*unnested_hlo));
+                      GetConsistentInputShapeForRootSlices(fusion));
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
       element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
@@ -5720,13 +5737,12 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
   Status emit_status =
       ParallelLoopEmitter(
           [&](const llvm_ir::IrArray::Index index) -> Status {
-            EmitElementForInputFusibleSlices(unnested_hlo, index);
-            return Status::OK();
+            return EmitElementForInputFusibleSlices(fusion, ir_arrays, index);
           },
           element_shape, launch_dimensions, &b_)
-          .EmitLoop(IrName(unnested_hlo),
-                    GetIndexTypeForKernel(
-                        unnested_hlo, launch_dimensions.launch_bound(), &b_));
+          .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
+                    GetIndexTypeForKernelFromMlir(
+                        fusion, launch_dimensions.launch_bound(), &b_));
 
   thunk_sequence_.emplace_back(std::move(kernel_thunk));
 
@@ -446,11 +446,12 @@ class IrEmitterUnnested : public IrEmitter,
  // different. On the other hand, the input ranges of slices can be
  // overlapping. Further generalization/specialization when the needs are seen
  // in the future.
-  Status EmitInputFusibleNonStridedSlices(HloInstruction* unnested_hlo);
+  Status EmitInputFusibleNonStridedSlices(MlirEmitterInput mlir_input);
 
-  void EmitElementForInputFusibleSlices(
-      HloInstruction* unnested_hlo,
-      const llvm_ir::IrArray::Index& slice_input_index);
+  Status EmitElementForInputFusibleSlices(
+      mlir::lmhlo::FusionOp fusion,
+      absl::Span<const llvm_ir::IrArray> ir_arrays,
+      const llvm_ir::IrArray::Index& index);
 
   // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
   // the process. Scatter indices are taken from `scatter_indices_gen`, updates
@@ -2,21 +2,21 @@
 
 // CHECK-LABEL: entry:
 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
-// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x i8*]*
+// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [1024 x half]*
 // CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
 // CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
 // CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
 // CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
 // CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
-// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [0 x half]*
+// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [1023 x half]*
 // CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
 // CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
 // CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
-// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]*
+// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1023 x half]*
 // CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
-// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1023 x half]*
+// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [0 x half]*
 // CHECK: %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0
-// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [1023 x half]*
+// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [3 x i8*]*
 // CHECK: %[[VAL_24:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
 // CHECK: %[[VAL_25:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
 // CHECK: %[[VAL_26:.*]] = mul nuw nsw i32 %[[VAL_24]], 1024
@@ -34,18 +34,18 @@
 // CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_31]]
 // CHECK: %[[VAL_38:.*]] = phi i32 [ 0, %[[VAL_31]] ]
 // CHECK: %[[VAL_39:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_38]]
-// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_39]]
+// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_2]], i32 0, i32 %[[VAL_39]]
 // CHECK: %[[VAL_41:.*]] = load half, half* %[[VAL_40]], align 2, !invariant.load !4
-// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_39]]
+// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_39]]
 // CHECK: %[[VAL_43:.*]] = load half, half* %[[VAL_42]], align 2, !invariant.load !4
 // CHECK: %[[VAL_44:.*]] = fmul half %[[VAL_41]], %[[VAL_43]]
 // CHECK: br label %[[VAL_45:.*]]
 // CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_37]]
 // CHECK: %[[VAL_46:.*]] = phi i32 [ 1024, %[[VAL_37]] ]
 // CHECK: %[[VAL_47:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_46]]
-// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_47]]
+// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_47]]
 // CHECK: %[[VAL_49:.*]] = load half, half* %[[VAL_48]], align 2, !invariant.load !4
-// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_47]]
+// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_47]]
 // CHECK: %[[VAL_51:.*]] = load half, half* %[[VAL_50]], align 2, !invariant.load !4
 // CHECK: %[[VAL_52:.*]] = fadd half %[[VAL_49]], %[[VAL_51]]
 // CHECK: br label %[[VAL_45]]
@@ -54,7 +54,7 @@
 // CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]]
 // CHECK: concat_index_not_from_operand1: ; preds = %[[VAL_37]]
 // CHECK: unreachable
-// CHECK: concat.1.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
+// CHECK: concatenate.7.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
 // CHECK: %[[VAL_56:.*]] = phi half [ %[[VAL_44]], %[[VAL_36]] ], [ %[[VAL_52]], %[[VAL_54]] ]
 // CHECK: %[[VAL_57:.*]] = icmp sge i32 %[[VAL_29]], 0
 // CHECK: %[[VAL_58:.*]] = icmp slt i32 %[[VAL_29]], 1024
@@ -74,17 +74,17 @@
 // CHECK: br label %[[VAL_32]]
 // CHECK: slice0-true: ; preds = %[[VAL_45]]
 // CHECK: %[[VAL_71:.*]] = sub i32 %[[VAL_29]], 0
-// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_71]]
+// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_71]]
 // CHECK: store half %[[VAL_56]], half* %[[VAL_72]], align 2
 // CHECK: br label %[[VAL_61]]
 // CHECK: slice1-true: ; preds = %[[VAL_61]]
 // CHECK: %[[VAL_73:.*]] = sub i32 %[[VAL_29]], 1024
-// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_73]]
+// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_73]]
 // CHECK: store half %[[VAL_56]], half* %[[VAL_74]], align 2
 // CHECK: br label %[[VAL_66]]
 // CHECK: slice2-true: ; preds = %[[VAL_66]]
 // CHECK: %[[VAL_75:.*]] = sub i32 %[[VAL_29]], 2047
-// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_75]]
+// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_75]]
 // CHECK: store half %[[VAL_56]], half* %[[VAL_76]], align 2
 // CHECK: br label %[[VAL_33]]
 
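Why so much CHECK churn with no functional change to the kernel body? Two presentation-level effects of the migration, inferred from the lines above:

// 1. Kernel argument order: the tuple-of-pointers buffer ([3 x i8*]*) moved
//    from the first bitcast (%[[VAL_2]]) to the last (%[[VAL_23]]), shifting
//    every other buffer's FileCheck capture by one slot.
// 2. Block names now derive from MLIR locations rather than HLO instruction
//    names: "concat.1.merge" becomes "concatenate.7.merge".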