[XLA/GPU] Add LLVM IR FileCheck tests for nested reduce.

PiperOrigin-RevId: 342354037
Change-Id: Icbd12573b07f121aa7be552cd85212d505ce1dfb
This commit is contained in:
Tim Shen 2020-11-13 15:49:03 -08:00 committed by TensorFlower Gardener
parent f0b8fa4d16
commit 134db25ce3
2 changed files with 169 additions and 0 deletions

View File

@ -0,0 +1,169 @@
// RUN: hlo_to_llvm_ir %s | FileCheck %s
// CHECK-LABEL: entry:
// CHECK: %[[VAL_0:.*]] = alloca float, align 4
// CHECK: %[[VAL_1:.*]] = alloca float, align 4
// CHECK: %[[VAL_2:.*]] = alloca float, align 4
// CHECK: %[[VAL_3:.*]] = alloca float, align 4
// CHECK: %[[VAL_4:.*]] = alloca float, align 4
// CHECK: %[[VAL_5:.*]] = alloca float, align 4
// CHECK: %[[VAL_6:.*]] = alloca [2 x i8*], align 8
// CHECK: %[[VAL_7:.*]] = alloca i32, align 4
// CHECK: %[[VAL_8:.*]] = alloca i32, align 4
// CHECK: %[[VAL_9:.*]] = alloca float, align 4
// CHECK: %[[VAL_10:.*]] = alloca float, align 4
// CHECK: %[[VAL_11:.*]] = getelementptr inbounds i8, i8* %[[VAL_12:.*]], i64 0
// CHECK: %[[VAL_13:.*]] = bitcast i8* %[[VAL_11]] to [2 x i8*]*
// CHECK: %[[VAL_14:.*]] = getelementptr inbounds i8, i8* %[[VAL_15:.*]], i64 0
// CHECK: %[[VAL_16:.*]] = bitcast i8* %[[VAL_14]] to [200 x float]*
// CHECK: %[[VAL_17:.*]] = getelementptr inbounds i8, i8* %[[VAL_18:.*]], i64 0
// CHECK: %[[VAL_19:.*]] = bitcast i8* %[[VAL_17]] to [200 x float]*
// CHECK: %[[VAL_20:.*]] = getelementptr inbounds i8, i8* %[[VAL_21:.*]], i64 0
// CHECK: %[[VAL_22:.*]] = bitcast i8* %[[VAL_20]] to [100 x [200 x [300 x float]]]*
// CHECK: %[[VAL_23:.*]] = getelementptr inbounds i8, i8* %[[VAL_24:.*]], i64 0
// CHECK: %[[VAL_25:.*]] = bitcast i8* %[[VAL_23]] to [100 x [200 x [300 x float]]]*
// CHECK: %[[VAL_26:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
// CHECK: %[[VAL_27:.*]] = icmp eq i32 0, %[[VAL_26]]
// CHECK: %[[VAL_28:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
// CHECK: %[[VAL_29:.*]] = icmp eq i32 0, %[[VAL_28]]
// CHECK: %[[VAL_30:.*]] = and i1 %[[VAL_27]], %[[VAL_29]]
// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_32:.*]]
// CHECK: emit_mof_tuple-after: ; preds = %[[VAL_31]], %[[VAL_33:.*]]
// CHECK: %[[VAL_34:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_35:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_36:.*]] = mul nuw nsw i32 %[[VAL_34]], 200
// CHECK: %[[VAL_37:.*]] = add nuw nsw i32 %[[VAL_36]], %[[VAL_35]]
// CHECK: %[[VAL_38:.*]] = icmp ult i32 %[[VAL_37]], 200
// CHECK: call void @llvm.assume(i1 %[[VAL_38]])
// CHECK: %[[VAL_39:.*]] = udiv i32 %[[VAL_37]], 1
// CHECK: %[[VAL_40:.*]] = icmp ult i32 %[[VAL_37]], 200
// CHECK: br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_42:.*]]
// CHECK: d.in_bounds-after: ; preds = %[[VAL_43:.*]], %[[VAL_32]]
// CHECK: ret void
// CHECK: emit_mof_tuple-true: ; preds = %[[VAL_33]]
// CHECK: %[[VAL_44:.*]] = bitcast [200 x float]* %[[VAL_16]] to i8*
// CHECK: %[[VAL_45:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_13]], i64 0, i64 0
// CHECK: store i8* %[[VAL_44]], i8** %[[VAL_45]], align 8
// CHECK: %[[VAL_46:.*]] = bitcast [200 x float]* %[[VAL_19]] to i8*
// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_13]], i64 0, i64 1
// CHECK: store i8* %[[VAL_46]], i8** %[[VAL_47]], align 8
// CHECK: br label %[[VAL_32]]
// CHECK: d.in_bounds-true: ; preds = %[[VAL_32]]
// CHECK: %[[VAL_48:.*]] = load float, float* bitcast ([4 x i8]* @buffer_for_c to float*), align 4, !invariant.load !4
// CHECK: store float %[[VAL_48]], float* %[[VAL_10]], align 4
// CHECK: %[[VAL_49:.*]] = load float, float* bitcast ([4 x i8]* @buffer_for_c to float*), align 4, !invariant.load !4
// CHECK: store float %[[VAL_49]], float* %[[VAL_9]], align 4
// CHECK: store i32 0, i32* %[[VAL_8]], align 4
// CHECK: br label %[[VAL_50:.*]]
// CHECK: d.inner.loop_header.reduction_dim.0: ; preds = %[[VAL_51:.*]], %[[VAL_41]]
// CHECK: %[[VAL_52:.*]] = load i32, i32* %[[VAL_8]], align 4
// CHECK: %[[VAL_53:.*]] = icmp uge i32 %[[VAL_52]], 100
// CHECK: br i1 %[[VAL_53]], label %[[VAL_43]], label %[[VAL_54:.*]]
// CHECK: d.inner.loop_body.reduction_dim.0: ; preds = %[[VAL_50]]
// CHECK: store i32 0, i32* %[[VAL_7]], align 4
// CHECK: br label %[[VAL_55:.*]]
// CHECK: d.inner.loop_header.reduction_dim.2: ; preds = %[[VAL_56:.*]], %[[VAL_54]]
// CHECK: %[[VAL_57:.*]] = load i32, i32* %[[VAL_7]], align 4
// CHECK: %[[VAL_58:.*]] = icmp uge i32 %[[VAL_57]], 300
// CHECK: br i1 %[[VAL_58]], label %[[VAL_51]], label %[[VAL_56]]
// CHECK: d.inner.loop_body.reduction_dim.2: ; preds = %[[VAL_55]]
// CHECK: %[[VAL_59:.*]] = load float, float* %[[VAL_10]], align 4
// CHECK: %[[VAL_60:.*]] = load float, float* %[[VAL_9]], align 4
// CHECK: %[[VAL_61:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_22]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
// CHECK: %[[VAL_62:.*]] = load float, float* %[[VAL_61]], align 4, !invariant.load !4
// CHECK: %[[VAL_63:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_25]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_39]], i32 %[[VAL_57]]
// CHECK: %[[VAL_64:.*]] = load float, float* %[[VAL_63]], align 4, !invariant.load !4
// CHECK: store float %[[VAL_59]], float* %[[VAL_5]], align 4
// CHECK: store float %[[VAL_60]], float* %[[VAL_4]], align 4
// CHECK: store float %[[VAL_62]], float* %[[VAL_3]], align 4
// CHECK: store float %[[VAL_64]], float* %[[VAL_2]], align 4
// CHECK: %[[VAL_65:.*]] = bitcast float* %[[VAL_0]] to i8*
// CHECK: %[[VAL_66:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_6]], i64 0, i64 0
// CHECK: store i8* %[[VAL_65]], i8** %[[VAL_66]], align 8
// CHECK: %[[VAL_67:.*]] = bitcast float* %[[VAL_1]] to i8*
// CHECK: %[[VAL_68:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_6]], i64 0, i64 1
// CHECK: store i8* %[[VAL_67]], i8** %[[VAL_68]], align 8
// CHECK: call void @Add(float* %[[VAL_5]], float* %[[VAL_4]], float* %[[VAL_3]], float* %[[VAL_2]], [2 x i8*]* %[[VAL_6]])
// CHECK: %[[VAL_69:.*]] = load float, float* %[[VAL_0]], align 4
// CHECK: %[[VAL_70:.*]] = load float, float* %[[VAL_1]], align 4
// CHECK: store float %[[VAL_69]], float* %[[VAL_10]], align 4
// CHECK: store float %[[VAL_70]], float* %[[VAL_9]], align 4
// CHECK: %[[VAL_71:.*]] = add nuw nsw i32 %[[VAL_57]], 1
// CHECK: store i32 %[[VAL_71]], i32* %[[VAL_7]], align 4
// CHECK: br label %[[VAL_55]]
// CHECK: d.inner.loop_exit.reduction_dim.2: ; preds = %[[VAL_55]]
// CHECK: %[[VAL_72:.*]] = add nuw nsw i32 %[[VAL_52]], 1
// CHECK: store i32 %[[VAL_72]], i32* %[[VAL_8]], align 4
// CHECK: br label %[[VAL_50]]
// CHECK: d.inner.loop_exit.reduction_dim.0: ; preds = %[[VAL_50]]
// CHECK: %[[VAL_73:.*]] = load float, float* %[[VAL_10]], align 4
// CHECK: %[[VAL_74:.*]] = insertvalue { float, float } undef, float %[[VAL_73]], 0
// CHECK: %[[VAL_75:.*]] = load float, float* %[[VAL_9]], align 4
// CHECK: %[[VAL_76:.*]] = insertvalue { float, float } %[[VAL_74]], float %[[VAL_75]], 1
// CHECK: %[[VAL_77:.*]] = extractvalue { float, float } %[[VAL_76]], 0
// CHECK: %[[VAL_78:.*]] = bitcast [200 x float]* %[[VAL_16]] to float*
// CHECK: %[[VAL_79:.*]] = getelementptr inbounds float, float* %[[VAL_78]], i32 %[[VAL_37]]
// CHECK: store float %[[VAL_77]], float* %[[VAL_79]], align 4
// CHECK: %[[VAL_80:.*]] = extractvalue { float, float } %[[VAL_76]], 1
// CHECK: %[[VAL_81:.*]] = bitcast [200 x float]* %[[VAL_19]] to float*
// CHECK: %[[VAL_82:.*]] = getelementptr inbounds float, float* %[[VAL_81]], i32 %[[VAL_37]]
// CHECK: store float %[[VAL_80]], float* %[[VAL_82]], align 4
// CHECK: br label %[[VAL_42]]
// CHECK: entry:
// CHECK: %[[VAL_83:.*]] = alloca float, align 4
// CHECK: %[[VAL_84:.*]] = alloca float, align 4
// CHECK: %[[VAL_85:.*]] = alloca [2 x i8*], align 8
// CHECK: %[[VAL_86:.*]] = alloca [2 x i8*], align 8
// CHECK: %[[VAL_87:.*]] = alloca [2 x i8*], align 8
// CHECK: %[[VAL_88:.*]] = bitcast [2 x i8*]* %[[VAL_86]] to float*
// CHECK: %[[VAL_89:.*]] = bitcast [2 x i8*]* %[[VAL_85]] to float*
// CHECK: %[[VAL_90:.*]] = load float, float* %[[VAL_91:.*]], align 4
// CHECK: %[[VAL_92:.*]] = load float, float* %[[VAL_93:.*]], align 4
// CHECK: %[[VAL_94:.*]] = fadd float %[[VAL_90]], %[[VAL_92]]
// CHECK: store float %[[VAL_94]], float* %[[VAL_84]], align 4
// CHECK: %[[VAL_95:.*]] = load float, float* %[[VAL_96:.*]], align 4
// CHECK: %[[VAL_97:.*]] = load float, float* %[[VAL_98:.*]], align 4
// CHECK: %[[VAL_99:.*]] = fadd float %[[VAL_95]], %[[VAL_97]]
// CHECK: store float %[[VAL_99]], float* %[[VAL_83]], align 4
// CHECK: %[[VAL_100:.*]] = bitcast float* %[[VAL_84]] to i8*
// CHECK: %[[VAL_101:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_87]], i64 0, i64 0
// CHECK: store i8* %[[VAL_100]], i8** %[[VAL_101]], align 8
// CHECK: %[[VAL_102:.*]] = bitcast float* %[[VAL_83]] to i8*
// CHECK: %[[VAL_103:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_87]], i64 0, i64 1
// CHECK: store i8* %[[VAL_102]], i8** %[[VAL_103]], align 8
// CHECK: %[[VAL_104:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_105:.*]], i64 0, i64 0
// CHECK: %[[VAL_106:.*]] = load i8*, i8** %[[VAL_104]], align 8, !dereferenceable !5, !align !6
// CHECK: %[[VAL_107:.*]] = bitcast i8* %[[VAL_106]] to float*
// CHECK: %[[VAL_108:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_87]], i64 0, i64 0
// CHECK: %[[VAL_109:.*]] = load i8*, i8** %[[VAL_108]], align 8, !dereferenceable !5, !align !6
// CHECK: %[[VAL_110:.*]] = bitcast i8* %[[VAL_109]] to float*
// CHECK: %[[VAL_111:.*]] = load float, float* %[[VAL_110]], align 4
// CHECK: store float %[[VAL_111]], float* %[[VAL_107]], align 4
// CHECK: %[[VAL_112:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_105]], i64 0, i64 1
// CHECK: %[[VAL_113:.*]] = load i8*, i8** %[[VAL_112]], align 8, !dereferenceable !5, !align !6
// CHECK: %[[VAL_114:.*]] = bitcast i8* %[[VAL_113]] to float*
// CHECK: %[[VAL_115:.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* %[[VAL_87]], i64 0, i64 1
// CHECK: %[[VAL_116:.*]] = load i8*, i8** %[[VAL_115]], align 8, !dereferenceable !5, !align !6
// CHECK: %[[VAL_117:.*]] = bitcast i8* %[[VAL_116]] to float*
// CHECK: %[[VAL_118:.*]] = load float, float* %[[VAL_117]], align 4
// CHECK: store float %[[VAL_118]], float* %[[VAL_114]], align 4
// CHECK: ret void
HloModule Test
Add {
scalar_lhs.0 = f32[] parameter(0)
scalar_rhs.0 = f32[] parameter(1)
scalar_lhs.1 = f32[] parameter(2)
scalar_rhs.1 = f32[] parameter(3)
add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
add.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
ROOT t = (f32[], f32[]) tuple(add.0, add.1)
}
ENTRY main {
a = f32[100, 200, 300]{2,1,0} parameter(0)
b = f32[100, 200, 300]{2,1,0} parameter(1)
c = f32[] constant(0)
ROOT d = (f32[200]{0}, f32[200]{0}) reduce(a, b, c, c), dimensions={0,2}, to_apply=Add
}