diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 8ce3af4916c..7f011cd9d6f 100644 --- a/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -25,15 +25,27 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +//===----------------------------------------------------------------------===// +// NVVM dialect definitions +//===----------------------------------------------------------------------===// + def NVVM_Dialect : Dialect { let name = "nvvm"; let cppNamespace = "NVVM"; } +//===----------------------------------------------------------------------===// +// NVVM op definitions +//===----------------------------------------------------------------------===// + class NVVM_Op traits = []> : LLVM_OpBase { } +//===----------------------------------------------------------------------===// +// NVVM special register op definitions +//===----------------------------------------------------------------------===// + class NVVM_SpecialRegisterOp traits = []> : NVVM_Op, @@ -44,14 +56,22 @@ class NVVM_SpecialRegisterOpgetOperation()); }]; } +//===----------------------------------------------------------------------===// +// Lane index and range def NVVM_LaneIdOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.laneid">; def NVVM_WarpSizeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.warpsize">; + +//===----------------------------------------------------------------------===// +// Thread index and range def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">; def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">; def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">; def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">; def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">; def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">; + +//===----------------------------------------------------------------------===// +// Block index and range def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">; def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">; def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">; @@ -59,6 +79,10 @@ def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">; def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">; def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">; +//===----------------------------------------------------------------------===// +// NVVM synchronization op definitions +//===----------------------------------------------------------------------===// + def NVVM_Barrier0Op : NVVM_Op<"barrier0"> { string llvmBuilder = [{ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);