diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 0e3bd296a52..7adad73fcb6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -369,29 +369,29 @@ bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce) {
   return reduction_dimensions.dimensions[1] >= kWarpSize;
 }
 
-bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
+bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                           bool verify_no_strides) {
-  if (!unnested_hlo.IsInputFusion()) {
+  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
+  if (!fusion) {
     return false;
   }
 
-  auto is_non_strided = [](const std::vector<int64>& strides) -> bool {
-    return absl::c_all_of(strides, [](int stride) { return stride == 1; });
+  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
+    return absl::c_all_of(
+        strides, [](const llvm::APInt& stride) { return stride == 1; });
   };
 
-  const HloInstruction* root = unnested_hlo.fused_expression_root();
-  if (root->opcode() == HloOpcode::kSlice) {
-    return !verify_no_strides || is_non_strided(root->slice_strides());
+  for (mlir::Value value : fusion.getFusionResults()) {
+    auto slice =
+        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
+    if (!slice) {
+      return false;
+    }
+    if (verify_no_strides && !is_non_strided(slice.strides())) {
+      return false;
+    }
   }
-
-  if (root->opcode() != HloOpcode::kTuple) {
-    return false;
-  }
-
-  return absl::c_all_of(root->operands(), [&](const HloInstruction* instr) {
-    return instr->opcode() == HloOpcode::kSlice &&
-           (!verify_no_strides || is_non_strided(instr->slice_strides()));
-  });
+  return true;
 }
 
 ReductionDimensions GetReductionKindAndContiguousComponents(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index bda4887b1a1..a7b4432452f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -168,8 +168,8 @@ bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce);
 // Returns whether unnested_hlo is an input fusion whose root is either a slice
 // or a tuple of slices. If verify_no_strides is true, returns false unless all
 // ROOT slices have no strides.
-bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
-                          bool verify_no_strides = false);
+bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
+                          bool verify_no_strides);
 
 struct ReductionDimensions {
   // Indicates whether the reduction is a row reduction or a column reduction.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 82c8db174e6..685976882d5 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -490,34 +490,36 @@ llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op,
 // slices are the same and the slices are non-strided. Otherwise, returns
 // FailedPrecondition.
 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
-    const HloInstruction& fusion) {
+    mlir::lmhlo::FusionOp fusion) {
   if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
     return FailedPrecondition(
         "Unsupported root for slice input fusion. "
" "Only non-strided slices are supported."); } - const HloInstruction& root = *fusion.fused_expression_root(); - if (root.opcode() == HloOpcode::kSlice) { - return root.operands()[0]->shape(); - } - - CHECK_EQ(root.opcode(), HloOpcode::kTuple); - const Shape& first_slice_operand_shape = - root.operands()[0]->operands()[0]->shape(); - for (size_t i = 1; i < root.operands().size(); ++i) { - const HloInstruction* slice = root.operands()[i]; - const Shape& operand_shape = slice->operands()[0]->shape(); - if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape, - operand_shape)) { - return FailedPrecondition( - "Fused slices do not have the same input shape, fused computation = " - "%s.", - root.parent()->name()); + absl::optional first_slice_operand_shape; + for (mlir::Value result : fusion.getFusionResults()) { + auto slice = + mlir::dyn_cast_or_null(result.getDefiningOp()); + if (!slice) { + return FailedPrecondition("Expected a slice op"); + } + if (first_slice_operand_shape.has_value()) { + Shape operand_shape = TypeToShape(slice.operand().getType()); + if (!ShapeUtil::EqualIgnoringElementType(*first_slice_operand_shape, + operand_shape)) { + return FailedPrecondition( + "Fused slices do not have the same input shape, instruction is %s", + MlirToString(fusion)); + } + } else { + first_slice_operand_shape = TypeToShape(slice.operand().getType()); } } - - return first_slice_operand_shape; + if (!first_slice_operand_shape.has_value()) { + return InvalidArgument("Fusion has no roots"); + } + return *first_slice_operand_shape; } } // namespace @@ -1842,8 +1844,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // In the case of root tuple, it can be either reduce or slice input // fusion. case HloOpcode::kTuple: { - if (IsInputFusibleSlices(*fusion)) { - return EmitInputFusibleNonStridedSlices(fusion); + if (IsInputFusibleSlices(mlir_input.op, /*verify_no_strides=*/false)) { + return EmitInputFusibleNonStridedSlices(mlir_input); } CHECK_GE(mlir::cast(mlir_input.op) @@ -1864,7 +1866,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return EmitReductionFromOrToContiguousDimensions(mlir_input); } case HloOpcode::kSlice: { - return EmitInputFusibleNonStridedSlices(fusion); + return EmitInputFusibleNonStridedSlices(mlir_input); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -5636,11 +5638,16 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( // Write to output of slice1 // } // -void IrEmitterUnnested::EmitElementForInputFusibleSlices( - HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index) { - VLOG(10) << "Emitting slice input fusion for " << unnested_hlo->ToString(); +Status IrEmitterUnnested::EmitElementForInputFusibleSlices( + mlir::lmhlo::FusionOp fusion, absl::Span ir_arrays, + const llvm_ir::IrArray::Index& index) { + VLOG(10) << "Emitting slice input fusion for " << MlirToString(fusion); - HloInstruction* slice_or_tuple = unnested_hlo->fused_expression_root(); + TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, + GetOrCreateSubComputationFromRegion(&fusion.region(), + /*is_fusion=*/true)); + + HloInstruction* slice_or_tuple = fused_computation->root_instruction(); auto slice_instructions = [&]() -> absl::Span { if (slice_or_tuple->opcode() == HloOpcode::kSlice) { return absl::Span(&slice_or_tuple, 1); @@ -5654,7 +5661,13 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices( GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); FusedIrEmitter 
-  BindFusionArguments(unnested_hlo, &fused_emitter);
+  for (int i = 0; i < fused_computation->num_parameters(); i++) {
+    fused_emitter.BindGenerator(
+        fused_computation->parameter_instruction(i),
+        [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
+          return ir_arrays[i].EmitReadArrayElement(index, &b_);
+        });
+  }
   for (const HloInstruction* slice : slice_instructions) {
     auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
     input_ir_values.push_back(input_generator(index).ValueOrDie());
@@ -5689,11 +5702,8 @@
           Sub(src_multidim[dim],
              index.GetConstantWithIndexType(slice->slice_starts(dim)));
    }
-    ShapeIndex shape_index = (slice_or_tuple->opcode() == HloOpcode::kSlice)
-                                 ? ShapeIndex()
-                                 : ShapeIndex({i});
    llvm_ir::IrArray src_ir_array =
-        GetIrArray(*unnested_hlo, *unnested_hlo, shape_index);
+        ir_arrays[fused_computation->num_parameters() + i];
    IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
                                   index.GetType());
    src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
@@ -5702,16 +5712,23 @@
     ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
   }
+  return Status::OK();
 }
 
 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
-    HloInstruction* unnested_hlo) {
+    MlirEmitterInput mlir_input) {
+  auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
+
   constexpr int unroll_factor = 1;
-  std::unique_ptr<KernelThunk> kernel_thunk =
-      BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/true);
+
+  std::vector<llvm_ir::IrArray> ir_arrays;
+  TF_ASSIGN_OR_RETURN(
+      auto kernel_thunk,
+      BuildKernelThunkForMlir(fusion, mlir_input.thunk_info,
+                              mlir_input.extra_slice, &ir_arrays));
 
   TF_ASSIGN_OR_RETURN(Shape element_shape,
-                      GetConsistentInputShapeForRootSlices(*unnested_hlo));
+                      GetConsistentInputShapeForRootSlices(fusion));
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
       element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
@@ -5720,13 +5737,12 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
   Status emit_status =
       ParallelLoopEmitter(
           [&](const llvm_ir::IrArray::Index index) -> Status {
-            EmitElementForInputFusibleSlices(unnested_hlo, index);
-            return Status::OK();
+            return EmitElementForInputFusibleSlices(fusion, ir_arrays, index);
          },
          element_shape, launch_dimensions, &b_)
-          .EmitLoop(IrName(unnested_hlo),
-                    GetIndexTypeForKernel(
-                        unnested_hlo, launch_dimensions.launch_bound(), &b_));
+          .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
+                    GetIndexTypeForKernelFromMlir(
+                        fusion, launch_dimensions.launch_bound(), &b_));
 
   thunk_sequence_.emplace_back(std::move(kernel_thunk));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 0480736045e..6cfed8dea33 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -446,11 +446,12 @@ class IrEmitterUnnested : public IrEmitter,
   // different. On the other hand, the input ranges of slices can be
   // overlapping. Further generalization/specialization when the needs are seen
   // in the future.
-  Status EmitInputFusibleNonStridedSlices(HloInstruction* unnested_hlo);
+  Status EmitInputFusibleNonStridedSlices(MlirEmitterInput mlir_input);
 
-  void EmitElementForInputFusibleSlices(
-      HloInstruction* unnested_hlo,
-      const llvm_ir::IrArray::Index& slice_input_index);
+  Status EmitElementForInputFusibleSlices(
+      mlir::lmhlo::FusionOp fusion,
+      absl::Span<const llvm_ir::IrArray> ir_arrays,
+      const llvm_ir::IrArray::Index& index);
 
   // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
   // the process. Scatter indices are taken from `scatter_indices_gen`, updates
diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
index aeb91766897..5964764a75a 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo
@@ -2,21 +2,21 @@
 
 // CHECK-LABEL: entry:
 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
-// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x i8*]*
+// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [1024 x half]*
 // CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
 // CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
 // CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
 // CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
 // CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
-// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [0 x half]*
+// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [1023 x half]*
 // CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
 // CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
 // CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
-// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]*
+// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1023 x half]*
 // CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
-// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1023 x half]*
+// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [0 x half]*
 // CHECK: %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0
-// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [1023 x half]*
+// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [3 x i8*]*
 // CHECK: %[[VAL_24:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
 // CHECK: %[[VAL_25:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
 // CHECK: %[[VAL_26:.*]] = mul nuw nsw i32 %[[VAL_24]], 1024
@@ -34,18 +34,18 @@
 // CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_31]]
 // CHECK: %[[VAL_38:.*]] = phi i32 [ 0, %[[VAL_31]] ]
 // CHECK: %[[VAL_39:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_38]]
-// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_39]]
+// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_2]], i32 0, i32 %[[VAL_39]]
 // CHECK: %[[VAL_41:.*]] = load half, half* %[[VAL_40]], align 2, !invariant.load !4
-// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_39]]
+// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_39]]
 // CHECK: %[[VAL_43:.*]] = load half, half* %[[VAL_42]], align 2, !invariant.load !4
 // CHECK: %[[VAL_44:.*]] = fmul half %[[VAL_41]], %[[VAL_43]]
 // CHECK: br label %[[VAL_45:.*]]
 // CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_37]]
 // CHECK: %[[VAL_46:.*]] = phi i32 [ 1024, %[[VAL_37]] ]
 // CHECK: %[[VAL_47:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_46]]
-// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_47]]
+// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_47]]
 // CHECK: %[[VAL_49:.*]] = load half, half* %[[VAL_48]], align 2, !invariant.load !4
-// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_47]]
+// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_47]]
 // CHECK: %[[VAL_51:.*]] = load half, half* %[[VAL_50]], align 2, !invariant.load !4
 // CHECK: %[[VAL_52:.*]] = fadd half %[[VAL_49]], %[[VAL_51]]
 // CHECK: br label %[[VAL_45]]
@@ -54,7 +54,7 @@
 // CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]]
 // CHECK: concat_index_not_from_operand1: ; preds = %[[VAL_37]]
 // CHECK: unreachable
-// CHECK: concat.1.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
+// CHECK: concatenate.7.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
 // CHECK: %[[VAL_56:.*]] = phi half [ %[[VAL_44]], %[[VAL_36]] ], [ %[[VAL_52]], %[[VAL_54]] ]
 // CHECK: %[[VAL_57:.*]] = icmp sge i32 %[[VAL_29]], 0
 // CHECK: %[[VAL_58:.*]] = icmp slt i32 %[[VAL_29]], 1024
@@ -74,17 +74,17 @@
 // CHECK: br label %[[VAL_32]]
 // CHECK: slice0-true: ; preds = %[[VAL_45]]
 // CHECK: %[[VAL_71:.*]] = sub i32 %[[VAL_29]], 0
-// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_71]]
+// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_71]]
 // CHECK: store half %[[VAL_56]], half* %[[VAL_72]], align 2
 // CHECK: br label %[[VAL_61]]
 // CHECK: slice1-true: ; preds = %[[VAL_61]]
 // CHECK: %[[VAL_73:.*]] = sub i32 %[[VAL_29]], 1024
-// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_73]]
+// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_73]]
 // CHECK: store half %[[VAL_56]], half* %[[VAL_74]], align 2
 // CHECK: br label %[[VAL_66]]
 // CHECK: slice2-true: ; preds = %[[VAL_66]]
 // CHECK: %[[VAL_75:.*]] = sub i32 %[[VAL_29]], 2047
-// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_75]]
+// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_75]]
 // CHECK: store half %[[VAL_56]], half* %[[VAL_76]], align 2
 // CHECK: br label %[[VAL_33]]
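
For reference, the pattern exercised by fused_slice.hlo is a kInput fusion whose root is a tuple of non-strided slices of a concatenate; the updated CHECK lines above appear to reflect mainly a different ordering of the kernel's buffer arguments (fusion operands first, slice outputs next, tuple pointer last) under the MLIR-based kernel thunk builder. A rough sketch of such a module, reconstructed from the CHECK patterns; instruction names and exact shapes are illustrative and may differ from the actual test body:

HloModule fused_slice

fused_computation {
  p0 = f16[1024]{0} parameter(0)
  p1 = f16[1024]{0} parameter(1)
  p2 = f16[1023]{0} parameter(2)
  p3 = f16[1023]{0} parameter(3)
  mul = f16[1024]{0} multiply(p0, p1)
  add = f16[1023]{0} add(p2, p3)
  concat = f16[2047]{0} concatenate(mul, add), dimensions={0}
  slice0 = f16[1024]{0} slice(concat), slice={[0:1024]}
  slice1 = f16[1023]{0} slice(concat), slice={[1024:2047]}
  slice2 = f16[0]{0} slice(concat), slice={[2047:2047]}
  ROOT tuple = (f16[1024]{0}, f16[1023]{0}, f16[0]{0}) tuple(slice0, slice1, slice2)
}

ENTRY entry {
  p0 = f16[1024]{0} parameter(0)
  p1 = f16[1024]{0} parameter(1)
  p2 = f16[1023]{0} parameter(2)
  p3 = f16[1023]{0} parameter(3)
  ROOT fusion = (f16[1024]{0}, f16[1023]{0}, f16[0]{0}) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
}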