[XLA:GPU] Add unit test to test PadToStatic IR emission
PiperOrigin-RevId: 341889831 Change-Id: I9e63d568c64fad69570ecceaa889ea88ad45c49a
This commit is contained in:
parent
6a83a28990
commit
c77657c395
91
tensorflow/compiler/xla/service/gpu/tests/pad_to_static.hlo
Normal file
91
tensorflow/compiler/xla/service/gpu/tests/pad_to_static.hlo
Normal file
@@ -0,0 +1,91 @@
|
||||
// RUN: hlo_to_llvm_ir %s | FileCheck %s
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
|
||||
// CHECK-LABEL: entry:
|
||||
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
|
||||
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [4 x i8*]*
|
||||
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
|
||||
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [2 x [2 x [2 x i32]]]*
|
||||
// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
|
||||
// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to i32*
|
||||
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
|
||||
// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to i32*
|
||||
// CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
|
||||
// CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to i32*
|
||||
// CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
|
||||
// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [2 x [2 x [2 x i32]]]*
|
||||
// CHECK: %[[VAL_18:.*]] = bitcast [2 x [2 x [2 x i32]]]* %[[VAL_17]] to i8*
|
||||
// CHECK: %[[VAL_19:.*]] = getelementptr inbounds i8, i8* %[[VAL_18]], i32 32
|
||||
// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_19]] to i32*
|
||||
// CHECK: %[[VAL_21:.*]] = load i32, i32* %[[VAL_20]], align 4
|
||||
// CHECK: %[[VAL_22:.*]] = getelementptr inbounds i8, i8* %[[VAL_18]], i32 36
|
||||
// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_22]] to i32*
|
||||
// CHECK: %[[VAL_24:.*]] = load i32, i32* %[[VAL_23]], align 4
|
||||
// CHECK: %[[VAL_25:.*]] = getelementptr inbounds i8, i8* %[[VAL_18]], i32 40
|
||||
// CHECK: %[[VAL_26:.*]] = bitcast i8* %[[VAL_25]] to i32*
|
||||
// CHECK: %[[VAL_27:.*]] = load i32, i32* %[[VAL_26]], align 4
|
||||
// CHECK: %[[VAL_28:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
// CHECK: %[[VAL_29:.*]] = icmp eq i32 0, %[[VAL_28]]
|
||||
// CHECK: %[[VAL_30:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
|
||||
// CHECK: %[[VAL_31:.*]] = icmp eq i32 0, %[[VAL_30]]
|
||||
// CHECK: %[[VAL_32:.*]] = and i1 %[[VAL_29]], %[[VAL_31]]
|
||||
// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]]
|
||||
// CHECK: is_thred_0-after: ; preds = %[[VAL_33]], %[[VAL_35:.*]]
|
||||
// CHECK: %[[VAL_36:.*]] = mul i32 1, %[[VAL_21]]
|
||||
// CHECK: %[[VAL_37:.*]] = mul i32 %[[VAL_36]], %[[VAL_24]]
|
||||
// CHECK: %[[VAL_38:.*]] = mul i32 %[[VAL_37]], %[[VAL_27]]
|
||||
// CHECK: %[[VAL_39:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
|
||||
// CHECK: %[[VAL_40:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
|
||||
// CHECK: %[[VAL_41:.*]] = mul nuw nsw i32 %[[VAL_39]], 8
|
||||
// CHECK: %[[VAL_42:.*]] = add nuw nsw i32 %[[VAL_41]], %[[VAL_40]]
|
||||
// CHECK: %[[VAL_43:.*]] = icmp ult i32 %[[VAL_42]], 8
|
||||
// CHECK: call void @llvm.assume(i1 %[[VAL_43]])
|
||||
// CHECK: %[[VAL_44:.*]] = udiv i32 %[[VAL_42]], 1
|
||||
// CHECK: %[[VAL_45:.*]] = urem i32 %[[VAL_44]], 2
|
||||
// CHECK: %[[VAL_46:.*]] = udiv i32 %[[VAL_42]], 2
|
||||
// CHECK: %[[VAL_47:.*]] = urem i32 %[[VAL_46]], 2
|
||||
// CHECK: %[[VAL_48:.*]] = udiv i32 %[[VAL_42]], 4
|
||||
// CHECK: %[[VAL_49:.*]] = icmp ult i32 %[[VAL_42]], 8
|
||||
// CHECK: br i1 %[[VAL_49]], label %[[VAL_50:.*]], label %[[VAL_51:.*]]
|
||||
// CHECK: custom-call.2.in_bounds-after: ; preds = %[[VAL_52:.*]], %[[VAL_34]]
|
||||
// CHECK: ret void
|
||||
// CHECK: is_thred_0-true: ; preds = %[[VAL_35]]
|
||||
// CHECK: store i32 %[[VAL_21]], i32* %[[VAL_8]], align 4
|
||||
// CHECK: store i32 %[[VAL_24]], i32* %[[VAL_11]], align 4
|
||||
// CHECK: store i32 %[[VAL_27]], i32* %[[VAL_14]], align 4
|
||||
// CHECK: br label %[[VAL_34]]
|
||||
// CHECK: custom-call.2.in_bounds-true: ; preds = %[[VAL_34]]
|
||||
// CHECK: %[[VAL_53:.*]] = mul nuw nsw i32 %[[VAL_45]], 1
|
||||
// CHECK: %[[VAL_54:.*]] = add nuw nsw i32 0, %[[VAL_53]]
|
||||
// CHECK: %[[VAL_55:.*]] = mul nuw nsw i32 %[[VAL_47]], 2
|
||||
// CHECK: %[[VAL_56:.*]] = add nuw nsw i32 %[[VAL_54]], %[[VAL_55]]
|
||||
// CHECK: %[[VAL_57:.*]] = mul nuw nsw i32 %[[VAL_48]], 4
|
||||
// CHECK: %[[VAL_58:.*]] = add nuw nsw i32 %[[VAL_56]], %[[VAL_57]]
|
||||
// CHECK: %[[VAL_59:.*]] = icmp ult i32 %[[VAL_58]], %[[VAL_38]]
|
||||
// CHECK: br i1 %[[VAL_59]], label %[[VAL_60:.*]], label %[[VAL_52]]
|
||||
// CHECK: custom-call.2.in_dyn_bounds-after: ; preds = %[[VAL_60]], %[[VAL_50]]
|
||||
// CHECK: br label %[[VAL_51]]
|
||||
// CHECK: custom-call.2.in_dyn_bounds-true: ; preds = %[[VAL_50]]
|
||||
// CHECK: %[[VAL_61:.*]] = udiv i32 %[[VAL_58]], 1
|
||||
// CHECK: %[[VAL_62:.*]] = urem i32 %[[VAL_61]], %[[VAL_27]]
|
||||
// CHECK: %[[VAL_63:.*]] = mul i32 1, %[[VAL_27]]
|
||||
// CHECK: %[[VAL_64:.*]] = udiv i32 %[[VAL_58]], %[[VAL_63]]
|
||||
// CHECK: %[[VAL_65:.*]] = urem i32 %[[VAL_64]], %[[VAL_24]]
|
||||
// CHECK: %[[VAL_66:.*]] = mul i32 %[[VAL_63]], %[[VAL_24]]
|
||||
// CHECK: %[[VAL_67:.*]] = udiv i32 %[[VAL_58]], %[[VAL_66]]
|
||||
// CHECK: %[[VAL_68:.*]] = bitcast [2 x [2 x [2 x i32]]]* %[[VAL_17]] to i32*
|
||||
// CHECK: %[[VAL_69:.*]] = getelementptr inbounds i32, i32* %[[VAL_68]], i32 %[[VAL_42]]
|
||||
// CHECK: %[[VAL_70:.*]] = load i32, i32* %[[VAL_69]], align 4, !invariant.load !4
|
||||
// CHECK: %[[VAL_71:.*]] = getelementptr inbounds [2 x [2 x [2 x i32]]], [2 x [2 x [2 x i32]]]* %[[VAL_5]], i32 0, i32 %[[VAL_67]], i32 %[[VAL_65]], i32 %[[VAL_62]]
|
||||
// CHECK: store i32 %[[VAL_70]], i32* %[[VAL_71]], align 4
|
||||
// CHECK: br label %[[VAL_52]]
|
||||
|
||||
// Test input module: a single custom call targeting "PadToStatic" applied to
// one parameter. The expected LLVM IR for this module is pinned by the
// FileCheck assertions earlier in this file.
HloModule PadToStatic
|
||||
|
||||
// Entry computation; its ROOT instruction is the custom call under test.
ENTRY main {
|
||||
// Sole operand: a rank-3 s32 array whose middle dimension is written `<=2`
// (XLA's notation for a dynamic dimension with an upper bound of 2).
%param = s32[2,<=2,2] parameter(0)
|
||||
// Result is a 4-tuple: one s32[2,2,2] array followed by three s32 scalars.
// In the expected IR, three i32 values are loaded at byte offsets 32, 36 and
// 40 of the source buffer (2*2*2 elements * 4 bytes = 32, i.e. just past the
// array data) and stored into three separate i32 outputs.
// NOTE: the instruction name `custom-call.2` is matched by the expected
// basic-block labels (e.g. `custom-call.2.in_bounds-true`), so renaming it
// requires regenerating the assertions.
ROOT %custom-call.2 = (s32[2,2,2], s32[], s32[], s32[])
|
||||
custom-call(s32[2,<=2,2] %param),
|
||||
custom_call_target="PadToStatic"
|
||||
}
|
Loading…
Reference in New Issue
Block a user