diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index ba337e29d30..3d697b78374 100644 --- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -174,11 +174,7 @@ def ICmpPredicate : I64EnumAttr< [ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE, ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE, ICmpPredicateUGT, ICmpPredicateUGE]> { - let cppNamespace = "mlir::LLVM"; - - let returnType = "ICmpPredicate"; - let convertFromStorage = - "static_cast<" # returnType # ">($_self.getValue().getZExtValue())"; + let cppNamespace = "::mlir::LLVM"; } // Other integer operations. @@ -225,11 +221,7 @@ def FCmpPredicate : I64EnumAttr< FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT, FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE ]> { - let cppNamespace = "mlir::LLVM"; - - let returnType = "FCmpPredicate"; - let convertFromStorage = - "static_cast<" # returnType # ">($_self.getValue().getZExtValue())"; + let cppNamespace = "::mlir::LLVM"; } // Other integer operations. diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index f88f87ec54a..07cdd7ac790 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -326,8 +326,6 @@ def SPV_AddressingModelAttr : SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64, SPV_AM_PhysicalStorageBuffer64 ]> { - let returnType = "::mlir::spirv::AddressingModel"; - let convertFromStorage = "static_cast<::mlir::spirv::AddressingModel>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -462,8 +460,6 @@ def SPV_BuiltInAttr : SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV, SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV ]> { - let returnType = "::mlir::spirv::BuiltIn"; - let convertFromStorage = "static_cast<::mlir::spirv::BuiltIn>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -672,8 +668,6 @@ def SPV_CapabilityAttr : SPV_C_SubgroupAvcMotionEstimationIntraINTEL, SPV_C_SubgroupAvcMotionEstimationChromaINTEL ]> { - let returnType = "::mlir::spirv::Capability"; - let convertFromStorage = "static_cast<::mlir::spirv::Capability>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -763,8 +757,6 @@ def SPV_DecorationAttr : SPV_D_AliasedPointer, SPV_D_CounterBuffer, SPV_D_UserSemantic, SPV_D_UserTypeGOOGLE ]> { - let returnType = "::mlir::spirv::Decoration"; - let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -781,8 +773,6 @@ def SPV_DimAttr : SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer, SPV_D_SubpassData ]> { - let returnType = "::mlir::spirv::Dim"; - let convertFromStorage = "static_cast<::mlir::spirv::Dim>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -866,8 +856,6 @@ def SPV_ExecutionModeAttr : SPV_EM_SampleInterlockOrderedEXT, SPV_EM_SampleInterlockUnorderedEXT, SPV_EM_ShadingRateInterlockOrderedEXT, SPV_EM_ShadingRateInterlockUnorderedEXT ]> { - let returnType = "::mlir::spirv::ExecutionMode"; - let convertFromStorage = "static_cast<::mlir::spirv::ExecutionMode>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -894,8 +882,6 @@ def SPV_ExecutionModelAttr : SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationNV, SPV_EM_IntersectionNV, SPV_EM_AnyHitNV, SPV_EM_ClosestHitNV, SPV_EM_MissNV, SPV_EM_CallableNV ]> { - let returnType = "::mlir::spirv::ExecutionModel"; - let convertFromStorage = "static_cast<::mlir::spirv::ExecutionModel>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -909,8 +895,6 @@ def SPV_FunctionControlAttr : BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [ SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const ]> { - let returnType = "::mlir::spirv::FunctionControl"; - let convertFromStorage = "static_cast<::mlir::spirv::FunctionControl>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -967,8 +951,6 @@ def SPV_ImageFormatAttr : SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui, SPV_IF_R8ui ]> { - let returnType = "::mlir::spirv::ImageFormat"; - let convertFromStorage = "static_cast<::mlir::spirv::ImageFormat>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -979,8 +961,6 @@ def SPV_LinkageTypeAttr : I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [ SPV_LT_Export, SPV_LT_Import ]> { - let returnType = "::mlir::spirv::LinkageType"; - let convertFromStorage = "static_cast<::mlir::spirv::LinkageType>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -1001,8 +981,6 @@ def SPV_LoopControlAttr : SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations, SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount ]> { - let returnType = "::mlir::spirv::LoopControl"; - let convertFromStorage = "static_cast<::mlir::spirv::LoopControl>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -1020,8 +998,6 @@ def SPV_MemoryAccessAttr : SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible, SPV_MA_NonPrivatePointer ]> { - let returnType = "::mlir::spirv::MemoryAccess"; - let convertFromStorage = "static_cast<::mlir::spirv::MemoryAccess>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -1034,8 +1010,6 @@ def SPV_MemoryModelAttr : I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [ SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan ]> { - let returnType = "::mlir::spirv::MemoryModel"; - let convertFromStorage = "static_cast<::mlir::spirv::MemoryModel>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -1063,8 +1037,6 @@ def SPV_MemorySemanticsAttr : SPV_MS_AtomicCounterMemory, SPV_MS_ImageMemory, SPV_MS_OutputMemory, SPV_MS_MakeAvailable, SPV_MS_MakeVisible, SPV_MS_Volatile ]> { - let returnType = "::mlir::spirv::MemorySemantics"; - let convertFromStorage = "static_cast<::mlir::spirv::MemorySemantics>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -1080,8 +1052,6 @@ def SPV_ScopeAttr : SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup, SPV_S_Invocation, SPV_S_QueueFamily ]> { - let returnType = "::mlir::spirv::Scope"; - let convertFromStorage = "static_cast<::mlir::spirv::Scope>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -1093,8 +1063,6 @@ def SPV_SelectionControlAttr : BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [ SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten ]> { - let returnType = "::mlir::spirv::SelectionControl"; - let convertFromStorage = "static_cast<::mlir::spirv::SelectionControl>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } @@ -1128,8 +1096,6 @@ def SPV_StorageClassAttr : SPV_SC_RayPayloadNV, SPV_SC_HitAttributeNV, SPV_SC_IncomingRayPayloadNV, SPV_SC_ShaderRecordBufferNV, SPV_SC_PhysicalStorageBuffer ]> { - let returnType = "::mlir::spirv::StorageClass"; - let convertFromStorage = "static_cast<::mlir::spirv::StorageClass>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } diff --git a/third_party/mlir/include/mlir/IR/OpBase.td b/third_party/mlir/include/mlir/IR/OpBase.td index bfe80e0148b..314acf6653e 100644 --- a/third_party/mlir/include/mlir/IR/OpBase.td +++ b/third_party/mlir/include/mlir/IR/OpBase.td @@ -938,12 +938,18 @@ class IntEnumAttr cases> : IntEnumAttr { + let returnType = cppNamespace # "::" # name; let underlyingType = "uint32_t"; + let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; + let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast($0))"; } class I64EnumAttr cases> : IntEnumAttr { + let returnType = cppNamespace # "::" # name; let underlyingType = "uint64_t"; + let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; + let constBuilderCall = "$_builder.getI64IntegerAttr(static_cast($0))"; } // A bit enum stored with 32-bit IntegerAttr. @@ -963,7 +969,10 @@ class BitEnumAttr ]>; + let returnType = cppNamespace # "::" # name; let underlyingType = "uint32_t"; + let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; + let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast($0))"; // We need to return a string because we may concatenate symbols for multiple // bits together. diff --git a/third_party/mlir/test/BUILD b/third_party/mlir/test/BUILD index 268e5cfc308..85680008d6d 100644 --- a/third_party/mlir/test/BUILD +++ b/third_party/mlir/test/BUILD @@ -44,6 +44,14 @@ gentbl( "-gen-op-defs", "lib/TestDialect/TestOps.cpp.inc", ), + ( + "-gen-enum-decls", + "lib/TestDialect/TestOpEnums.h.inc", + ), + ( + "-gen-enum-defs", + "lib/TestDialect/TestOpEnums.cpp.inc", + ), ( "-gen-rewriters", "lib/TestDialect/TestPatterns.inc", @@ -75,6 +83,7 @@ cc_library( ], deps = [ ":TestOpsIncGen", + "@llvm//:support", "@local_config_mlir//:Analysis", "@local_config_mlir//:Dialect", "@local_config_mlir//:IR", diff --git a/third_party/mlir/test/lib/TestDialect/CMakeLists.txt b/third_party/mlir/test/lib/TestDialect/CMakeLists.txt index a0e0ce0311b..e6a22833de4 100644 --- a/third_party/mlir/test/lib/TestDialect/CMakeLists.txt +++ b/third_party/mlir/test/lib/TestDialect/CMakeLists.txt @@ -6,6 +6,8 @@ set(LLVM_OPTIONAL_SOURCES set(LLVM_TARGET_DEFINITIONS TestOps.td) mlir_tablegen(TestOps.h.inc -gen-op-decls) mlir_tablegen(TestOps.cpp.inc -gen-op-defs) +mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls) +mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs) mlir_tablegen(TestPatterns.inc -gen-rewriters) add_public_tablegen_target(MLIRTestOpsIncGen) diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp index 3c7fbee3671..60a16d968dc 100644 --- a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSwitch.h" using namespace mlir; @@ -304,5 +305,7 @@ SmallVector mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( // Static initialization for Test dialect registration. static mlir::DialectRegistration testDialect; +#include "TestOpEnums.cpp.inc" + #define GET_OP_CLASSES #include "TestOps.cpp.inc" diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.h b/third_party/mlir/test/lib/TestDialect/TestDialect.h index f10b9845680..783b8a1bcdd 100644 --- a/third_party/mlir/test/lib/TestDialect/TestDialect.h +++ b/third_party/mlir/test/lib/TestDialect/TestDialect.h @@ -32,6 +32,8 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/IR/SymbolTable.h" +#include "TestOpEnums.h.inc" + namespace mlir { class TestDialect : public Dialect { diff --git a/third_party/mlir/test/lib/TestDialect/TestOps.td b/third_party/mlir/test/lib/TestDialect/TestOps.td index d804fdc1b78..e8ca8b82487 100644 --- a/third_party/mlir/test/lib/TestDialect/TestOps.td +++ b/third_party/mlir/test/lib/TestDialect/TestOps.td @@ -694,7 +694,7 @@ def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>; def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>; def MultiResultOpEnum: I64EnumAttr< - "Multi-result op kinds", "", [ + "MultiResultOpEnum", "Multi-result op kinds", [ MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3, MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6 ]>; diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py index 723a4095c75..9aed98dba70 100755 --- a/third_party/mlir/utils/spirv/gen_spirv_dialect.py +++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py @@ -200,9 +200,6 @@ def gen_operand_kind_enum_attr(operand_kind): enum_attr = 'def SPV_{name}Attr :\n '\ '{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n'\ ' ]> {{\n'\ - ' let returnType = "::mlir::spirv::{name}";\n'\ - ' let convertFromStorage = '\ - '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\ ' let cppNamespace = "::mlir::spirv";\n}}'.format( name=kind_name, category=kind_category, cases=case_names) return kind_name, case_defs + '\n\n' + enum_attr @@ -240,9 +237,6 @@ def gen_opcode(instructions): ' I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\ '{lst}\n'\ ' ]> {{\n'\ - ' let returnType = "::mlir::spirv::{name}";\n'\ - ' let convertFromStorage = '\ - '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\ ' let cppNamespace = "::mlir::spirv";\n}}'.format( name='Opcode', lst=opcode_list) return opcode_str + '\n\n' + enum_attr