[XLA:GPU] Add test for SliceToDynamic IR emission.
- Also fixed IsBlock0Thread0 by splitting the nested CreateICmpEQ calls into separate statements, which prevents compiler-dependent changes in the order of the generated instructions. PiperOrigin-RevId: 341860846 Change-Id: Ic65ee73d3b916e925e9745abcfb21186909feed4
This commit is contained in:
parent
2de245e6ee
commit
d75a0910ec
@ -486,13 +486,14 @@ string CudnnConvKindToString(CudnnConvKind kind) {
|
||||
}
|
||||
|
||||
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
|
||||
return b->CreateAnd(
|
||||
b->CreateICmpEQ(
|
||||
llvm::Value* is_thread0 = b->CreateICmpEQ(
|
||||
b->getInt32(0),
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)),
|
||||
b->CreateICmpEQ(
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b));
|
||||
|
||||
llvm::Value* is_block0 = b->CreateICmpEQ(
|
||||
b->getInt32(0),
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b)));
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b));
|
||||
return b->CreateAnd(is_thread0, is_block0);
|
||||
}
|
||||
|
||||
bool AreFusedReductionOutputsConsistent(
|
||||
|
@ -0,0 +1,88 @@
|
||||
// RUN: hlo_to_llvm_ir %s | FileCheck %s
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
|
||||
// CHECK-LABEL: entry:
|
||||
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
|
||||
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [2 x [2 x [2 x i32]]]*
|
||||
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
|
||||
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [2 x [2 x [2 x i32]]]*
|
||||
// CHECK: %[[VAL_6:.*]] = bitcast [2 x [2 x [2 x i32]]]* %[[VAL_2]] to i8*
|
||||
// CHECK: %[[VAL_7:.*]] = load i32, i32* bitcast ([4 x i8]* @buffer_for_static to i32*), align 4
|
||||
// CHECK: %[[VAL_8:.*]] = load i32, i32* bitcast ([4 x i8]* @buffer_for_dynamic to i32*), align 4
|
||||
// CHECK: %[[VAL_9:.*]] = load i32, i32* bitcast ([4 x i8]* @buffer_for_static to i32*), align 4
|
||||
// CHECK: %[[VAL_10:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
// CHECK: %[[VAL_11:.*]] = icmp eq i32 0, %[[VAL_10]]
|
||||
// CHECK: %[[VAL_12:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
|
||||
// CHECK: %[[VAL_13:.*]] = icmp eq i32 0, %[[VAL_12]]
|
||||
// CHECK: %[[VAL_14:.*]] = and i1 %[[VAL_11]], %[[VAL_13]]
|
||||
// CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]]
|
||||
// CHECK: is_thred_0-after: ; preds = %[[VAL_15]], %[[VAL_17:.*]]
|
||||
// CHECK: %[[VAL_18:.*]] = mul i32 1, %[[VAL_7]]
|
||||
// CHECK: %[[VAL_19:.*]] = mul i32 %[[VAL_18]], %[[VAL_8]]
|
||||
// CHECK: %[[VAL_20:.*]] = mul i32 %[[VAL_19]], %[[VAL_9]]
|
||||
// CHECK: %[[VAL_21:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
|
||||
// CHECK: %[[VAL_22:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
|
||||
// CHECK: %[[VAL_23:.*]] = mul nuw nsw i32 %[[VAL_21]], 8
|
||||
// CHECK: %[[VAL_24:.*]] = add nuw nsw i32 %[[VAL_23]], %[[VAL_22]]
|
||||
// CHECK: %[[VAL_25:.*]] = icmp ult i32 %[[VAL_24]], 8
|
||||
// CHECK: call void @llvm.assume(i1 %[[VAL_25]])
|
||||
// CHECK: %[[VAL_26:.*]] = udiv i32 %[[VAL_24]], 1
|
||||
// CHECK: %[[VAL_27:.*]] = urem i32 %[[VAL_26]], 2
|
||||
// CHECK: %[[VAL_28:.*]] = udiv i32 %[[VAL_24]], 2
|
||||
// CHECK: %[[VAL_29:.*]] = urem i32 %[[VAL_28]], 2
|
||||
// CHECK: %[[VAL_30:.*]] = udiv i32 %[[VAL_24]], 4
|
||||
// CHECK: %[[VAL_31:.*]] = icmp ult i32 %[[VAL_24]], 8
|
||||
// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]]
|
||||
// CHECK: custom-call.in_bounds-after: ; preds = %[[VAL_34:.*]], %[[VAL_16]]
|
||||
// CHECK: ret void
|
||||
// CHECK: is_thred_0-true: ; preds = %[[VAL_17]]
|
||||
// CHECK: %[[VAL_35:.*]] = getelementptr inbounds i8, i8* %[[VAL_6]], i32 32
|
||||
// CHECK: %[[VAL_36:.*]] = bitcast i8* %[[VAL_35]] to i32*
|
||||
// CHECK: store i32 %[[VAL_7]], i32* %[[VAL_36]], align 4
|
||||
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_6]], i32 36
|
||||
// CHECK: %[[VAL_38:.*]] = bitcast i8* %[[VAL_37]] to i32*
|
||||
// CHECK: store i32 %[[VAL_8]], i32* %[[VAL_38]], align 4
|
||||
// CHECK: %[[VAL_39:.*]] = getelementptr inbounds i8, i8* %[[VAL_6]], i32 40
|
||||
// CHECK: %[[VAL_40:.*]] = bitcast i8* %[[VAL_39]] to i32*
|
||||
// CHECK: store i32 %[[VAL_9]], i32* %[[VAL_40]], align 4
|
||||
// CHECK: br label %[[VAL_16]]
|
||||
// CHECK: custom-call.in_bounds-true: ; preds = %[[VAL_16]]
|
||||
// CHECK: %[[VAL_41:.*]] = mul nuw nsw i32 %[[VAL_27]], 1
|
||||
// CHECK: %[[VAL_42:.*]] = add nuw nsw i32 0, %[[VAL_41]]
|
||||
// CHECK: %[[VAL_43:.*]] = mul nuw nsw i32 %[[VAL_30]], 2
|
||||
// CHECK: %[[VAL_44:.*]] = add nuw nsw i32 %[[VAL_42]], %[[VAL_43]]
|
||||
// CHECK: %[[VAL_45:.*]] = mul nuw nsw i32 %[[VAL_29]], 4
|
||||
// CHECK: %[[VAL_46:.*]] = add nuw nsw i32 %[[VAL_44]], %[[VAL_45]]
|
||||
// CHECK: %[[VAL_47:.*]] = icmp ult i32 %[[VAL_46]], %[[VAL_20]]
|
||||
// CHECK: br i1 %[[VAL_47]], label %[[VAL_48:.*]], label %[[VAL_34]]
|
||||
// CHECK: custom-call.in_dyn_bounds-after: ; preds = %[[VAL_48]], %[[VAL_32]]
|
||||
// CHECK: br label %[[VAL_33]]
|
||||
// CHECK: custom-call.in_dyn_bounds-true: ; preds = %[[VAL_32]]
|
||||
// CHECK: %[[VAL_49:.*]] = udiv i32 %[[VAL_46]], 1
|
||||
// CHECK: %[[VAL_50:.*]] = urem i32 %[[VAL_49]], %[[VAL_9]]
|
||||
// CHECK: %[[VAL_51:.*]] = mul i32 1, %[[VAL_9]]
|
||||
// CHECK: %[[VAL_52:.*]] = udiv i32 %[[VAL_46]], %[[VAL_51]]
|
||||
// CHECK: %[[VAL_53:.*]] = urem i32 %[[VAL_52]], %[[VAL_7]]
|
||||
// CHECK: %[[VAL_54:.*]] = mul i32 %[[VAL_51]], %[[VAL_7]]
|
||||
// CHECK: %[[VAL_55:.*]] = udiv i32 %[[VAL_46]], %[[VAL_54]]
|
||||
// CHECK: %[[VAL_56:.*]] = getelementptr inbounds [2 x [2 x [2 x i32]]], [2 x [2 x [2 x i32]]]* %[[VAL_5]], i32 0, i32 %[[VAL_55]], i32 %[[VAL_53]], i32 %[[VAL_50]]
|
||||
// CHECK: %[[VAL_57:.*]] = load i32, i32* %[[VAL_56]], align 4, !invariant.load !4
|
||||
// CHECK: %[[VAL_58:.*]] = bitcast [2 x [2 x [2 x i32]]]* %[[VAL_2]] to i32*
|
||||
// CHECK: %[[VAL_59:.*]] = getelementptr inbounds i32, i32* %[[VAL_58]], i32 %[[VAL_24]]
|
||||
// CHECK: store i32 %[[VAL_57]], i32* %[[VAL_59]], align 4
|
||||
// CHECK: br label %[[VAL_34]]
|
||||
|
||||
HloModule SliceToDynamic
|
||||
|
||||
ENTRY main {
|
||||
%param = s32[2,2,2]{2,0,1} parameter(0)
|
||||
%static = s32[] constant(2)
|
||||
%dynamic = s32[] constant(1)
|
||||
ROOT %custom-call = s32[2,<=2, 2]{2,0,1} custom-call(s32[2,2,2]{2,0,1} %param,
|
||||
s32[] %static,
|
||||
s32[] %dynamic,
|
||||
s32[] %static),
|
||||
custom_call_target="SliceToDynamic",
|
||||
backend_config=""
|
||||
}
|
Loading…
Reference in New Issue
Block a user