[XLA/GPU] Add an LLVM IR test for fused slices.

PiperOrigin-RevId: 355427820
Change-Id: Ie97e17df23676db41a160677a4a79eaacb028a10
This commit is contained in:
Tim Shen 2021-02-03 10:33:33 -08:00 committed by TensorFlower Gardener
parent 3d7a7d2516
commit 8ddc7459a8

View File

@ -0,0 +1,115 @@
// RUN: hlo_to_llvm_ir %s | FileCheck %s
// CHECK-LABEL: entry:
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x i8*]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [0 x half]*
// CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
// CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
// CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]*
// CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1023 x half]*
// CHECK: %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0
// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [1023 x half]*
// CHECK: %[[VAL_24:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_25:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_26:.*]] = mul nuw nsw i32 %[[VAL_24]], 1024
// CHECK: %[[VAL_27:.*]] = add nuw nsw i32 %[[VAL_26]], %[[VAL_25]]
// CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_27]], 2048
// CHECK: call void @llvm.assume(i1 %[[VAL_28]])
// CHECK: %[[VAL_29:.*]] = udiv i32 %[[VAL_27]], 1
// CHECK: %[[VAL_30:.*]] = icmp ult i32 %[[VAL_27]], 2047
// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_32:.*]]
// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_33:.*]], %[[VAL_34:.*]]
// CHECK: ret void
// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_34]]
// CHECK: %[[VAL_35:.*]] = icmp ult i32 %[[VAL_29]], 1024
// CHECK: br i1 %[[VAL_35]], label %[[VAL_36:.*]], label %[[VAL_37:.*]]
// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_31]]
// CHECK: %[[VAL_38:.*]] = phi i32 [ 0, %[[VAL_31]] ]
// CHECK: %[[VAL_39:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_38]]
// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_39]]
// CHECK: %[[VAL_41:.*]] = load half, half* %[[VAL_40]], align 2, !invariant.load !4
// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_39]]
// CHECK: %[[VAL_43:.*]] = load half, half* %[[VAL_42]], align 2, !invariant.load !4
// CHECK: %[[VAL_44:.*]] = fmul half %[[VAL_41]], %[[VAL_43]]
// CHECK: br label %[[VAL_45:.*]]
// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_37]]
// CHECK: %[[VAL_46:.*]] = phi i32 [ 1024, %[[VAL_37]] ]
// CHECK: %[[VAL_47:.*]] = sub nsw i32 %[[VAL_29]], %[[VAL_46]]
// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_47]]
// CHECK: %[[VAL_49:.*]] = load half, half* %[[VAL_48]], align 2, !invariant.load !4
// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_47]]
// CHECK: %[[VAL_51:.*]] = load half, half* %[[VAL_50]], align 2, !invariant.load !4
// CHECK: %[[VAL_52:.*]] = fadd half %[[VAL_49]], %[[VAL_51]]
// CHECK: br label %[[VAL_45]]
// CHECK: concat_index_not_from_operand0: ; preds = %[[VAL_31]]
// CHECK: %[[VAL_53:.*]] = icmp ult i32 %[[VAL_29]], 2047
// CHECK: br i1 %[[VAL_53]], label %[[VAL_54:.*]], label %[[VAL_55:.*]]
// CHECK: concat_index_not_from_operand1: ; preds = %[[VAL_37]]
// CHECK: unreachable
// CHECK: concat.1.merge: ; preds = %[[VAL_54]], %[[VAL_36]]
// CHECK: %[[VAL_56:.*]] = phi half [ %[[VAL_44]], %[[VAL_36]] ], [ %[[VAL_52]], %[[VAL_54]] ]
// CHECK: %[[VAL_57:.*]] = icmp sge i32 %[[VAL_29]], 0
// CHECK: %[[VAL_58:.*]] = icmp slt i32 %[[VAL_29]], 1024
// CHECK: %[[VAL_59:.*]] = and i1 %[[VAL_57]], %[[VAL_58]]
// CHECK: br i1 %[[VAL_59]], label %[[VAL_60:.*]], label %[[VAL_61:.*]]
// CHECK: slice0-after: ; preds = %[[VAL_60]], %[[VAL_45]]
// CHECK: %[[VAL_62:.*]] = icmp sge i32 %[[VAL_29]], 1024
// CHECK: %[[VAL_63:.*]] = icmp slt i32 %[[VAL_29]], 2047
// CHECK: %[[VAL_64:.*]] = and i1 %[[VAL_62]], %[[VAL_63]]
// CHECK: br i1 %[[VAL_64]], label %[[VAL_65:.*]], label %[[VAL_66:.*]]
// CHECK: slice1-after: ; preds = %[[VAL_65]], %[[VAL_61]]
// CHECK: %[[VAL_67:.*]] = icmp sge i32 %[[VAL_29]], 2047
// CHECK: %[[VAL_68:.*]] = icmp slt i32 %[[VAL_29]], 2047
// CHECK: %[[VAL_69:.*]] = and i1 %[[VAL_67]], %[[VAL_68]]
// CHECK: br i1 %[[VAL_69]], label %[[VAL_70:.*]], label %[[VAL_33]]
// CHECK: slice2-after: ; preds = %[[VAL_70]], %[[VAL_66]]
// CHECK: br label %[[VAL_32]]
// CHECK: slice0-true: ; preds = %[[VAL_45]]
// CHECK: %[[VAL_71:.*]] = sub i32 %[[VAL_29]], 0
// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_71]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_72]], align 2
// CHECK: br label %[[VAL_61]]
// CHECK: slice1-true: ; preds = %[[VAL_61]]
// CHECK: %[[VAL_73:.*]] = sub i32 %[[VAL_29]], 1024
// CHECK: %[[VAL_74:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_73]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_74]], align 2
// CHECK: br label %[[VAL_66]]
// CHECK: slice2-true: ; preds = %[[VAL_66]]
// CHECK: %[[VAL_75:.*]] = sub i32 %[[VAL_29]], 2047
// CHECK: %[[VAL_76:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_75]]
// CHECK: store half %[[VAL_56]], half* %[[VAL_76]], align 2
// CHECK: br label %[[VAL_33]]
HloModule input_fusion_with_a_tuple_of_slices
fused_computation {
arg.1 = f16[1024]{0} parameter(0)
arg.2 = f16[1024]{0} parameter(1)
arg.3 = f16[1023]{0} parameter(2)
arg.4 = f16[1023]{0} parameter(3)
mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
add.1 = f16[1023]{0} add(arg.3, arg.4)
concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0}
slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]}
slice.2 = f16[1023]{0} slice(concat.1), slice={[1024:2047]}
slice.3 = f16[0]{0} slice(concat.1), slice={[2047:2047]}
ROOT tuple.1 = (f16[1024]{0}, f16[1023]{0}, f16[0]{0})
tuple(slice.1, slice.2, slice.3)
}
ENTRY kernel_entry {
arg.1 = f16[1024]{0} parameter(0)
arg.2 = f16[1024]{0} parameter(1)
arg.3 = f16[1023]{0} parameter(2)
arg.4 = f16[1023]{0} parameter(3)
ROOT fusion = (f16[1024]{0}, f16[1023]{0}, f16[0]{0})
fusion(arg.1, arg.2, arg.3, arg.4), kind=kInput, calls=fused_computation
}