From 2dc7a885a06dababc4cafe719b54a96ebdeab484 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 16 Dec 2019 15:05:21 -0800 Subject: [PATCH] Add atomic operations to SPIR-V dialect. Some changes to the dialect generation script to allow specification of different base class to derive from in ODS. PiperOrigin-RevId: 285859230 Change-Id: I14e96a424e7c63f93435561074ed8b52e4ce78da --- .../mlir/Dialect/SPIRV/SPIRVAtomicOps.td | 495 +++++++++++++++++- .../include/mlir/Dialect/SPIRV/SPIRVBase.td | 24 +- .../GPUToSPIRV/ConvertGPUToSPIRV.cpp | 5 +- .../mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 2 +- .../mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 65 +++ third_party/mlir/utils/spirv/define_inst.sh | 4 +- .../mlir/utils/spirv/gen_spirv_dialect.py | 7 +- 7 files changed, 585 insertions(+), 17 deletions(-) diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td index d8e62bb13b8..15b6ab0105c 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td @@ -23,6 +23,84 @@ #ifndef SPIRV_ATOMIC_OPS #define SPIRV_ATOMIC_OPS +class SPV_AtomicUpdateOp traits = []> : + SPV_Op { + let parser = [{ return ::parseAtomicUpdateOp(parser, result, false); }]; + let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; + let verifier = [{ return ::verifyAtomicUpdateOp(getOperation()); }]; + + let arguments = (ins + SPV_AnyPtr:$pointer, + SPV_ScopeAttr:$memory_scope, + SPV_MemorySemanticsAttr:$semantics + ); + let results = (outs + SPV_Integer:$result + ); +} + +class SPV_AtomicUpdateWithValueOp traits = []> : + SPV_Op { + let parser = [{ return ::parseAtomicUpdateOp(parser, result, true); }]; + let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; + let verifier = [{ return ::verifyAtomicUpdateOp(getOperation()); }]; + + let arguments = (ins + SPV_AnyPtr:$pointer, + SPV_ScopeAttr:$memory_scope, + SPV_MemorySemanticsAttr:$semantics, + SPV_Integer:$value + ); + let results = (outs + SPV_Integer:$result + ); +} + +// ----- + +def SPV_AtomicAndOp : SPV_AtomicUpdateWithValueOp<"AtomicAnd", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by the bitwise AND of Original Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + scope ::= `"CrossDevice"` | `"Device"` | `"Workgroup"` | ... + + memory-semantics ::= `"None"` | `"Acquire"` | "Release"` | ... + + atomic-and-op ::= + `spv.AtomicAnd` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicAnd "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + // ----- def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> { @@ -36,10 +114,6 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> { ### Custom assembly form ``` - scope ::= `"CrossDevice"` | `"Device"` | `"Workgroup"` | ... - - memory-semantics ::= `"None"` | `"Acquire"` | "Release"` | ... - atomic-compare-exchange-weak-op ::= `spv.AtomicCompareExchangeWeak` scope memory-semantics memory-semantics ssa-use `,` ssa-use `,` ssa-use @@ -71,4 +145,417 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> { // ----- +def SPV_AtomicIAddOp : SPV_AtomicUpdateWithValueOp<"AtomicIAdd", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by integer addition of Original Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-iadd-op ::= + `spv.AtomicIAdd` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicIAdd "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicIDecrementOp : SPV_AtomicUpdateOp<"AtomicIDecrement", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value through integer subtraction of 1 from Original Value, + and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. The type of the value + pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-idecrement-op ::= + `spv.AtomicIDecrement` scope memory-semantics ssa-use + `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicIDecrement "Device" "None" %pointer : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicIIncrementOp : SPV_AtomicUpdateOp<"AtomicIIncrement", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value through integer addition of 1 to Original Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. The type of the value + pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-iincrement-op ::= + `spv.AtomicIIncrement` scope memory-semantics ssa-use + `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicIncrement "Device" "None" %pointer : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicISubOp : SPV_AtomicUpdateWithValueOp<"AtomicISub", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by integer subtraction of Value from Original Value, + and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-isub-op ::= + `spv.AtomicISub` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicISub "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicOrOp : SPV_AtomicUpdateWithValueOp<"AtomicOr", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by the bitwise OR of Original Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-or-op ::= + `spv.AtomicOr` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicOr "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicSMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicSMax", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by finding the largest signed integer of Original + Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-smax-op ::= + `spv.AtomicSMax` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicSMax "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicSMinOp : SPV_AtomicUpdateWithValueOp<"AtomicSMin", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by finding the smallest signed integer of Original + Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-smin-op ::= + `spv.AtomicSMin` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicSMin "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by finding the largest unsigned integer of Original + Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-umax-op ::= + `spv.AtomicUMax` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicUMax "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by finding the smallest unsigned integer of Original + Value and Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-umin-op ::= + `spv.AtomicUMin` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicUMin "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + +def SPV_AtomicXorOp : SPV_AtomicUpdateWithValueOp<"AtomicXor", []> { + let summary = [{ + Perform the following steps atomically with respect to any other atomic + accesses within Scope to the same location: + }]; + + let description = [{ + 1) load through Pointer to get an Original Value, + + 2) get a New Value by the bitwise exclusive OR of Original Value and + Value, and + + 3) store the New Value back through Pointer. + + The instruction’s result is the Original Value. + + Result Type must be an integer type scalar. + + The type of Value must be the same as Result Type. The type of the + value pointed to by Pointer must be the same as Result Type. + + Memory must be a valid memory Scope. + + ### Custom assembly form + + ``` + atomic-xor-op ::= + `spv.AtomicXor` scope memory-semantics + ssa-use `,` ssa-use `:` spv-pointer-type + ``` + + For example: + + ``` + %0 = spv.AtomicXor "Device" "None" %pointer, %value : + !spv.ptr + ``` + }]; +} + +// ----- + #endif // SPIRV_ATOMIC_OPS diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 8368a626ffc..838398823ad 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -1144,6 +1144,17 @@ def SPV_OC_OpBitCount : I32EnumAttrCase<"OpBitCount", 205>; def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>; def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>; def SPV_OC_OpAtomicCompareExchangeWeak : I32EnumAttrCase<"OpAtomicCompareExchangeWeak", 231>; +def SPV_OC_OpAtomicIIncrement : I32EnumAttrCase<"OpAtomicIIncrement", 232>; +def SPV_OC_OpAtomicIDecrement : I32EnumAttrCase<"OpAtomicIDecrement", 233>; +def SPV_OC_OpAtomicIAdd : I32EnumAttrCase<"OpAtomicIAdd", 234>; +def SPV_OC_OpAtomicISub : I32EnumAttrCase<"OpAtomicISub", 235>; +def SPV_OC_OpAtomicSMin : I32EnumAttrCase<"OpAtomicSMin", 236>; +def SPV_OC_OpAtomicUMin : I32EnumAttrCase<"OpAtomicUMin", 237>; +def SPV_OC_OpAtomicSMax : I32EnumAttrCase<"OpAtomicSMax", 238>; +def SPV_OC_OpAtomicUMax : I32EnumAttrCase<"OpAtomicUMax", 239>; +def SPV_OC_OpAtomicAnd : I32EnumAttrCase<"OpAtomicAnd", 240>; +def SPV_OC_OpAtomicOr : I32EnumAttrCase<"OpAtomicOr", 241>; +def SPV_OC_OpAtomicXor : I32EnumAttrCase<"OpAtomicXor", 242>; def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>; def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>; def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>; @@ -1194,11 +1205,14 @@ def SPV_OpcodeAttr : SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, - SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, - SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, - SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, - SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformBallot, - SPV_OC_OpSubgroupBallotKHR + SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement, + SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub, + SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax, + SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, + SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, + SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, + SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed, + SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR ]> { let cppNamespace = "::mlir::spirv"; } diff --git a/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp index a8747a7a9bf..92cc02660a2 100644 --- a/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -234,8 +234,9 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter, "lowering as entry functions requires ABI info for all arguments"); return nullptr; } - // For entry functions need to make the signature void(void). Compute the - // replacement value for all arguments and replace all uses. + // Update the signature to valid SPIR-V types and add the ABI + // attributes. These will be "materialized" by using the + // LowerABIAttributesPass. TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); { for (auto argType : enumerate(funcOp.getType().getInputs())) { diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 1e68b49c48b..284fe915029 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -246,7 +246,7 @@ Value *mlir::spirv::getBuiltinVariableValue(Operation *op, } //===----------------------------------------------------------------------===// -// Entry Function signature Conversion +// Set ABI attributes for lowering entry functions. //===----------------------------------------------------------------------===// LogicalResult diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 839f134ec8f..140470b8df8 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -49,6 +49,7 @@ static constexpr const char kIndicesAttrName[] = "indices"; static constexpr const char kInitializerAttrName[] = "initializer"; static constexpr const char kInterfaceAttrName[] = "interface"; static constexpr const char kMemoryScopeAttrName[] = "memory_scope"; +static constexpr const char kSemanticsAttrName[] = "semantics"; static constexpr const char kSpecConstAttrName[] = "spec_const"; static constexpr const char kSpecIdAttrName[] = "spec_id"; static constexpr const char kTypeAttrName[] = "type"; @@ -514,6 +515,70 @@ static LogicalResult verifyBitFieldExtractOp(Operation *op) { return success(); } +// Parses an atomic update op. If the update op does not take a value (like +// AtomicIIncrement) `hasValue` must be false. +static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, + OperationState &state, bool hasValue) { + spirv::Scope scope; + spirv::MemorySemantics memoryScope; + SmallVector operandInfo; + OpAsmParser::OperandType ptrInfo, valueInfo; + Type type; + llvm::SMLoc loc; + if (parseEnumAttribute(scope, parser, state, kMemoryScopeAttrName) || + parseEnumAttribute(memoryScope, parser, state, kSemanticsAttrName) || + parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) || + parser.getCurrentLocation(&loc) || parser.parseColonType(type)) + return failure(); + + auto ptrType = type.dyn_cast(); + if (!ptrType) + return parser.emitError(loc, "expected pointer type"); + + SmallVector operandTypes; + operandTypes.push_back(ptrType); + if (hasValue) + operandTypes.push_back(ptrType.getPointeeType()); + if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(), + state.operands)) + return failure(); + return parser.addTypeToList(ptrType.getPointeeType(), state.types); +} + +// Prints an atomic update op. +static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { + printer << op->getName() << " \""; + auto scopeAttr = op->getAttrOfType(kMemoryScopeAttrName); + printer << spirv::stringifyScope( + static_cast(scopeAttr.getInt())) + << "\" \""; + auto memorySemanticsAttr = op->getAttrOfType(kSemanticsAttrName); + printer << spirv::stringifyMemorySemantics( + static_cast( + memorySemanticsAttr.getInt())) + << "\" " << op->getOperands() << " : " + << op->getOperand(0)->getType(); +} + +// Verifies an atomic update op. +static LogicalResult verifyAtomicUpdateOp(Operation *op) { + auto ptrType = op->getOperand(0)->getType().cast(); + auto elementType = ptrType.getPointeeType(); + if (!elementType.isa()) + return op->emitOpError( + "pointer operand must point to an integer value, found ") + << elementType; + + if (op->getNumOperands() > 1) { + auto valueType = op->getOperand(1)->getType(); + if (valueType != elementType) + return op->emitOpError("expected value to have the same type as the " + "pointer operand's pointee type ") + << elementType << ", but found " << valueType; + } + return success(); +} + // Parses an op that has no inputs and no outputs. static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { if (parser.parseOptionalAttrDict(state.attributes)) diff --git a/third_party/mlir/utils/spirv/define_inst.sh b/third_party/mlir/utils/spirv/define_inst.sh index 3508c4f9b4f..f11078a8e76 100755 --- a/third_party/mlir/utils/spirv/define_inst.sh +++ b/third_party/mlir/utils/spirv/define_inst.sh @@ -35,13 +35,13 @@ file_name=$1 inst_category=$2 case $inst_category in - Op | ArithmeticOp | LogicalOp | CastOp | ControlFlowOp | StructureOp) + Op | ArithmeticOp | LogicalOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp) ;; *) echo "Usage : " $0 " ()*" echo " is the file name of MLIR SPIR-V op definitions spec" echo " must be one of " \ - "(Op|ArithmeticOp|LogicalOp|CastOp|ControlFlowOp|StructureOp)" + "(Op|ArithmeticOp|LogicalOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)" exit 1; ;; esac diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py index bf4886dfd51..be7116c211f 100755 --- a/third_party/mlir/utils/spirv/gen_spirv_dialect.py +++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py @@ -353,7 +353,7 @@ def map_spec_operand_to_ods_argument(operand): # and 'IdScope' given that they should be generated from OpConstant. assert quantifier == '', ('unexpected to have optional/variadic memory ' 'semantics or scope ') - arg_type = 'I32' + arg_type = 'SPV_' + kind[2:] + 'Attr' elif kind == 'LiteralInteger': if quantifier == '': arg_type = 'I32Attr' @@ -651,8 +651,9 @@ def update_td_op_definitions(path, instructions, docs, filter_list, instruction = next( inst for inst in instructions if inst['opname'] == opname) op_defs.append( - get_op_definition(instruction, docs[opname], - op_info_dict.get(opname, {}))) + get_op_definition( + instruction, docs[opname], + op_info_dict.get(opname, {'inst_category': inst_category}))) except StopIteration: # This is an op added by us; use the existing ODS definition. op_defs.append(name_op_map[opname])