diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo new file mode 100644 index 00000000000..aeb91766897 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/fused_slice.hlo @@ -0,0 +1,115 @@ +// RUN: hlo_to_llvm_ir %s | FileCheck %s + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 +// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x i8*]* +// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0 +// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]* +// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0 +// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]* +// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0 +// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [0 x half]* +// CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0 +// CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]* +// CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0 +// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]* +// CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0 +// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1023 x half]* +// CHECK: %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0 +// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [1023 x half]* +// CHECK: %[[VAL_24:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 +// CHECK: %[[VAL_25:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK: %[[VAL_26:.*]] = mul nuw nsw i32 %[[VAL_24]], 1024 +// CHECK: %[[VAL_27:.*]] = add nuw nsw i32 %[[VAL_26]], %[[VAL_25]] +// CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_27]], 2048 +// CHECK: call void @llvm.assume(i1 %[[VAL_28]]) +// CHECK: %[[VAL_29:.*]] = udiv i32 %[[VAL_27]], 1 +// CHECK: %[[VAL_30:.*]] = icmp ult i32 %[[VAL_27]], 2047 +// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_32:.*]] +// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_33:.*]], %[[VAL_34:.*]] +// CHECK: ret void +// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_34]] +// CHECK: %[[VAL_35:.*]] = icmp ult i32 %[[VAL_29]], 1024 +// CHECK: br i1 %[[VAL_35]], label %[[VAL_36:.*]], label %[[VAL_37:.*]] +// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_31]] +// CHECK: %[[VAL_38:.*]] = phi i32 [ 0, %[[VAL_31]] ] +// CHECK: %[[VAL_39:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_38]] +// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_39]] +// CHECK: %[[VAL_41:.*]] = load half, half* %[[VAL_40]], align 2, !invariant.load !4 +// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_39]] +// CHECK: %[[VAL_43:.*]] = load half, half* %[[VAL_42]], align 2, !invariant.load !4 +// CHECK: %[[VAL_44:.*]] = fmul half %[[VAL_41]], %[[VAL_43]] +// CHECK: br label %[[VAL_45:.*]] +// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_37]] +// CHECK: %[[VAL_46:.*]] = phi i32 [ 1024, %[[VAL_37]] ] +// CHECK: %[[VAL_47:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_46]] +// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_47]] +// CHECK: %[[VAL_49:.*]] = load half, half* %[[VAL_48]], align 2, !invariant.load !4 +// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_47]] +// CHECK: %[[VAL_51:.*]] = load half, half* %[[VAL_50]], align 2, !invariant.load !4 +// CHECK: %[[VAL_52:.*]] = fadd half %[[VAL_49]], %[[VAL_51]] +// CHECK: br label %[[VAL_45]] +// CHECK: concat_index_not_from_operand0: ; preds = %[[VAL_31]] +// CHECK: %[[VAL_53:.*]] = icmp ult i32 %[[VAL_29]], 2047 +// CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]] +// CHECK: concat_index_not_from_operand1: ; preds = %[[VAL_37]] +// CHECK: unreachable +// CHECK: concat.1.merge: ; preds = %[[VAL_54]], %[[VAL_36]] +// CHECK: %[[VAL_56:.*]] = phi half [ %[[VAL_44]], %[[VAL_36]] ], [ %[[VAL_52]], %[[VAL_54]] ] +// CHECK: %[[VAL_57:.*]] = icmp sge i32 %[[VAL_29]], 0 +// CHECK: %[[VAL_58:.*]] = icmp slt i32 %[[VAL_29]], 1024 +// CHECK: %[[VAL_59:.*]] = and i1 %[[VAL_57]], %[[VAL_58]] +// CHECK: br i1 %[[VAL_59]], label %[[VAL_60:.*]], label %[[VAL_61:.*]] +// CHECK: slice0-after: ; preds = %[[VAL_60]], %[[VAL_45]] +// CHECK: %[[VAL_62:.*]] = icmp sge i32 %[[VAL_29]], 1024 +// CHECK: %[[VAL_63:.*]] = icmp slt i32 %[[VAL_29]], 2047 +// CHECK: %[[VAL_64:.*]] = and i1 %[[VAL_62]], %[[VAL_63]] +// CHECK: br i1 %[[VAL_64]], label %[[VAL_65:.*]], label %[[VAL_66:.*]] +// CHECK: slice1-after: ; preds = %[[VAL_65]], %[[VAL_61]] +// CHECK: %[[VAL_67:.*]] = icmp sge i32 %[[VAL_29]], 2047 +// CHECK: %[[VAL_68:.*]] = icmp slt i32 %[[VAL_29]], 2047 +// CHECK: %[[VAL_69:.*]] = and i1 %[[VAL_67]], %[[VAL_68]] +// CHECK: br i1 %[[VAL_69]], label %[[VAL_70:.*]], label %[[VAL_33]] +// CHECK: slice2-after: ; preds = %[[VAL_70]], %[[VAL_66]] +// CHECK: br label %[[VAL_32]] +// CHECK: slice0-true: ; preds = %[[VAL_45]] +// CHECK: %[[VAL_71:.*]] = sub i32 %[[VAL_29]], 0 +// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_71]] +// CHECK: store half %[[VAL_56]], half* %[[VAL_72]], align 2 +// CHECK: br label %[[VAL_61]] +// CHECK: slice1-true: ; preds = %[[VAL_61]] +// CHECK: %[[VAL_73:.*]] = sub i32 %[[VAL_29]], 1024 +// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_73]] +// CHECK: store half %[[VAL_56]], half* %[[VAL_74]], align 2 +// CHECK: br label %[[VAL_66]] +// CHECK: slice2-true: ; preds = %[[VAL_66]] +// CHECK: %[[VAL_75:.*]] = sub i32 %[[VAL_29]], 2047 +// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_75]] +// CHECK: store half %[[VAL_56]], half* %[[VAL_76]], align 2 +// CHECK: br label %[[VAL_33]] + +HloModule input_fusion_with_a_tuple_of_slices + +fused_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + arg.3 = f16[1023]{0} parameter(2) + arg.4 = f16[1023]{0} parameter(3) + mul.1 = f16[1024]{0} multiply(arg.1, arg.2) + add.1 = f16[1023]{0} add(arg.3, arg.4) + concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0} + slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]} + slice.2 = f16[1023]{0} slice(concat.1), slice={[1024:2047]} + slice.3 = f16[0]{0} slice(concat.1), slice={[2047:2047]} + ROOT tuple.1 = (f16[1024]{0}, f16[1023]{0}, f16[0]{0}) + tuple(slice.1, slice.2, slice.3) +} + +ENTRY kernel_entry { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + arg.3 = f16[1023]{0} parameter(2) + arg.4 = f16[1023]{0} parameter(3) + ROOT fusion = (f16[1024]{0}, f16[1023]{0}, f16[0]{0}) + fusion(arg.1, arg.2, arg.3, arg.4), kind=kInput, calls=fused_computation +}