[XLA/GPU] Add an LLVM IR test for in-place dynamic update slice.

PiperOrigin-RevId: 354157612
Change-Id: I4bc110f816ba25fbfb338643bcfd894886c6b322
This commit is contained in:
Tim Shen 2021-01-27 13:17:49 -08:00 committed by TensorFlower Gardener
parent 56ecc5e516
commit 150bad0f8b

View File

@ -0,0 +1,69 @@
// RUN: hlo_to_llvm_ir %s | FileCheck %s
// CHECK-LABEL: entry:
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [50 x [96 x [1024 x half]]]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_1]], i64 0
// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_3]] to [50 x [96 x [1024 x half]]]*
// CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0
// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [1 x [96 x [1024 x half]]]*
// CHECK: %[[VAL_8:.*]] = load i32, i32* bitcast ([4 x i8]* @0 to i32*), align 4
// CHECK: %[[VAL_9:.*]] = icmp sge i32 0, %[[VAL_8]]
// CHECK: %[[VAL_10:.*]] = select i1 %[[VAL_9]], i32 0, i32 %[[VAL_8]]
// CHECK: %[[VAL_11:.*]] = icmp sle i32 49, %[[VAL_10]]
// CHECK: %[[VAL_12:.*]] = select i1 %[[VAL_11]], i32 49, i32 %[[VAL_10]]
// CHECK: %[[VAL_13:.*]] = load i32, i32* bitcast ([4 x i8]* @1 to i32*), align 4
// CHECK: %[[VAL_14:.*]] = icmp sge i32 0, %[[VAL_13]]
// CHECK: %[[VAL_15:.*]] = select i1 %[[VAL_14]], i32 0, i32 %[[VAL_13]]
// CHECK: %[[VAL_16:.*]] = icmp sle i32 0, %[[VAL_15]]
// CHECK: %[[VAL_17:.*]] = select i1 %[[VAL_16]], i32 0, i32 %[[VAL_15]]
// CHECK: %[[VAL_18:.*]] = load i32, i32* bitcast ([4 x i8]* @2 to i32*), align 4
// CHECK: %[[VAL_19:.*]] = icmp sge i32 0, %[[VAL_18]]
// CHECK: %[[VAL_20:.*]] = select i1 %[[VAL_19]], i32 0, i32 %[[VAL_18]]
// CHECK: %[[VAL_21:.*]] = icmp sle i32 0, %[[VAL_20]]
// CHECK: %[[VAL_22:.*]] = select i1 %[[VAL_21]], i32 0, i32 %[[VAL_20]]
// CHECK: %[[VAL_23:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_24:.*]] = zext i32 %[[VAL_23]] to i64
// CHECK: %[[VAL_25:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_26:.*]] = zext i32 %[[VAL_25]] to i64
// CHECK: %[[VAL_27:.*]] = mul nuw nsw i64 %[[VAL_24]], 1024
// CHECK: %[[VAL_28:.*]] = add nuw nsw i64 %[[VAL_27]], %[[VAL_26]]
// CHECK: %[[VAL_29:.*]] = icmp ult i64 %[[VAL_28]], 98304
// CHECK: call void @llvm.assume(i1 %[[VAL_29]])
// CHECK: %[[VAL_30:.*]] = udiv i64 %[[VAL_28]], 1
// CHECK: %[[VAL_31:.*]] = urem i64 %[[VAL_30]], 1024
// CHECK: %[[VAL_32:.*]] = udiv i64 %[[VAL_28]], 1024
// CHECK: %[[VAL_33:.*]] = urem i64 %[[VAL_32]], 96
// CHECK: %[[VAL_34:.*]] = udiv i64 %[[VAL_28]], 98304
// CHECK: %[[VAL_35:.*]] = icmp ult i64 %[[VAL_28]], 98304
// CHECK: br i1 %[[VAL_35]], label %[[VAL_36:.*]], label %[[VAL_37:.*]]
// CHECK: f1.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]]
// CHECK: ret void
// CHECK: f1.in_bounds-true: ; preds = %[[VAL_38]]
// CHECK: %[[VAL_39:.*]] = sext i32 %[[VAL_12]] to i64
// CHECK: %[[VAL_40:.*]] = add i64 %[[VAL_39]], %[[VAL_34]]
// CHECK: %[[VAL_41:.*]] = sext i32 %[[VAL_17]] to i64
// CHECK: %[[VAL_42:.*]] = add i64 %[[VAL_41]], %[[VAL_33]]
// CHECK: %[[VAL_43:.*]] = sext i32 %[[VAL_22]] to i64
// CHECK: %[[VAL_44:.*]] = add i64 %[[VAL_43]], %[[VAL_31]]
// CHECK: %[[VAL_45:.*]] = bitcast [1 x [96 x [1024 x half]]]* %[[VAL_7]] to half*
// CHECK: %[[VAL_46:.*]] = getelementptr inbounds half, half* %[[VAL_45]], i64 %[[VAL_28]]
// CHECK: %[[VAL_47:.*]] = load half, half* %[[VAL_46]], align 2, !invariant.load !4
// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [50 x [96 x [1024 x half]]], [50 x [96 x [1024 x half]]]* %[[VAL_2]], i64 0, i64 %[[VAL_40]], i64 %[[VAL_42]], i64 %[[VAL_44]]
// CHECK: store half %[[VAL_47]], half* %[[VAL_48]], align 2
// CHECK: br label %[[VAL_37]]
HloModule TestModule
fusion.1 {
p.0 = f16[50,96,1024]{2,1,0} parameter(0)
p.1 = f16[1,96,1024]{2,1,0} parameter(1)
c.0 = s32[] constant(0)
ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0, c.0, c.0)
}
ENTRY entry {
p.0 = f16[50,96,1024]{2,1,0} parameter(0)
p.1 = f16[1,96,1024]{2,1,0} parameter(1)
ROOT f1 = f16[50,96,1024] fusion(p.0, p.1), kind=kLoop, calls=fusion.1
}