diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index d2c141bf231..7d47aae2270 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -409,6 +409,7 @@ cc_library( "include/mlir/Support/MathExtras.h", "include/mlir/Support/STLExtras.h", "include/mlir/Support/StorageUniquer.h", + "include/mlir/Support/StringExtras.h", ], includes = ["include"], deps = [ diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index a12f3390125..a820c11dbdb 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -196,6 +196,97 @@ def SPV_AddressingModelAttr : let cppNamespace = "::mlir::spirv"; } +def SPV_D_RelaxedPrecision : I32EnumAttrCase<"RelaxedPrecision", 0>; +def SPV_D_SpecId : I32EnumAttrCase<"SpecId", 1>; +def SPV_D_Block : I32EnumAttrCase<"Block", 2>; +def SPV_D_BufferBlock : I32EnumAttrCase<"BufferBlock", 3>; +def SPV_D_RowMajor : I32EnumAttrCase<"RowMajor", 4>; +def SPV_D_ColMajor : I32EnumAttrCase<"ColMajor", 5>; +def SPV_D_ArrayStride : I32EnumAttrCase<"ArrayStride", 6>; +def SPV_D_MatrixStride : I32EnumAttrCase<"MatrixStride", 7>; +def SPV_D_GLSLShared : I32EnumAttrCase<"GLSLShared", 8>; +def SPV_D_GLSLPacked : I32EnumAttrCase<"GLSLPacked", 9>; +def SPV_D_CPacked : I32EnumAttrCase<"CPacked", 10>; +def SPV_D_BuiltIn : I32EnumAttrCase<"BuiltIn", 11>; +def SPV_D_NoPerspective : I32EnumAttrCase<"NoPerspective", 13>; +def SPV_D_Flat : I32EnumAttrCase<"Flat", 14>; +def SPV_D_Patch : I32EnumAttrCase<"Patch", 15>; +def SPV_D_Centroid : I32EnumAttrCase<"Centroid", 16>; +def SPV_D_Sample : I32EnumAttrCase<"Sample", 17>; +def SPV_D_Invariant : I32EnumAttrCase<"Invariant", 18>; +def SPV_D_Restrict : I32EnumAttrCase<"Restrict", 19>; +def SPV_D_Aliased : I32EnumAttrCase<"Aliased", 20>; +def SPV_D_Volatile : I32EnumAttrCase<"Volatile", 21>; +def SPV_D_Constant : I32EnumAttrCase<"Constant", 22>; +def SPV_D_Coherent : I32EnumAttrCase<"Coherent", 23>; +def SPV_D_NonWritable : I32EnumAttrCase<"NonWritable", 24>; +def SPV_D_NonReadable : I32EnumAttrCase<"NonReadable", 25>; +def SPV_D_Uniform : I32EnumAttrCase<"Uniform", 26>; +def SPV_D_UniformId : I32EnumAttrCase<"UniformId", 27>; +def SPV_D_SaturatedConversion : I32EnumAttrCase<"SaturatedConversion", 28>; +def SPV_D_Stream : I32EnumAttrCase<"Stream", 29>; +def SPV_D_Location : I32EnumAttrCase<"Location", 30>; +def SPV_D_Component : I32EnumAttrCase<"Component", 31>; +def SPV_D_Index : I32EnumAttrCase<"Index", 32>; +def SPV_D_Binding : I32EnumAttrCase<"Binding", 33>; +def SPV_D_DescriptorSet : I32EnumAttrCase<"DescriptorSet", 34>; +def SPV_D_Offset : I32EnumAttrCase<"Offset", 35>; +def SPV_D_XfbBuffer : I32EnumAttrCase<"XfbBuffer", 36>; +def SPV_D_XfbStride : I32EnumAttrCase<"XfbStride", 37>; +def SPV_D_FuncParamAttr : I32EnumAttrCase<"FuncParamAttr", 38>; +def SPV_D_FPRoundingMode : I32EnumAttrCase<"FPRoundingMode", 39>; +def SPV_D_FPFastMathMode : I32EnumAttrCase<"FPFastMathMode", 40>; +def SPV_D_LinkageAttributes : I32EnumAttrCase<"LinkageAttributes", 41>; +def SPV_D_NoContraction : I32EnumAttrCase<"NoContraction", 42>; +def SPV_D_InputAttachmentIndex : I32EnumAttrCase<"InputAttachmentIndex", 43>; +def SPV_D_Alignment : I32EnumAttrCase<"Alignment", 44>; +def SPV_D_MaxByteOffset : I32EnumAttrCase<"MaxByteOffset", 45>; +def SPV_D_AlignmentId : I32EnumAttrCase<"AlignmentId", 46>; +def SPV_D_MaxByteOffsetId : I32EnumAttrCase<"MaxByteOffsetId", 47>; +def SPV_D_NoSignedWrap : I32EnumAttrCase<"NoSignedWrap", 4469>; +def SPV_D_NoUnsignedWrap : I32EnumAttrCase<"NoUnsignedWrap", 4470>; +def SPV_D_ExplicitInterpAMD : I32EnumAttrCase<"ExplicitInterpAMD", 4999>; +def SPV_D_OverrideCoverageNV : I32EnumAttrCase<"OverrideCoverageNV", 5248>; +def SPV_D_PassthroughNV : I32EnumAttrCase<"PassthroughNV", 5250>; +def SPV_D_ViewportRelativeNV : I32EnumAttrCase<"ViewportRelativeNV", 5252>; +def SPV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewportRelativeNV", 5256>; +def SPV_D_PerPrimitiveNV : I32EnumAttrCase<"PerPrimitiveNV", 5271>; +def SPV_D_PerViewNV : I32EnumAttrCase<"PerViewNV", 5272>; +def SPV_D_PerTaskNV : I32EnumAttrCase<"PerTaskNV", 5273>; +def SPV_D_PerVertexNV : I32EnumAttrCase<"PerVertexNV", 5285>; +def SPV_D_NonUniformEXT : I32EnumAttrCase<"NonUniformEXT", 5300>; +def SPV_D_RestrictPointerEXT : I32EnumAttrCase<"RestrictPointerEXT", 5355>; +def SPV_D_AliasedPointerEXT : I32EnumAttrCase<"AliasedPointerEXT", 5356>; +def SPV_D_CounterBuffer : I32EnumAttrCase<"CounterBuffer", 5634>; +def SPV_D_UserSemantic : I32EnumAttrCase<"UserSemantic", 5635>; +def SPV_D_UserTypeGOOGLE : I32EnumAttrCase<"UserTypeGOOGLE", 5636>; + +def SPV_DecorationAttr : + I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [ + SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock, + SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride, + SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn, + SPV_D_NoPerspective, SPV_D_Flat, SPV_D_Patch, SPV_D_Centroid, SPV_D_Sample, + SPV_D_Invariant, SPV_D_Restrict, SPV_D_Aliased, SPV_D_Volatile, SPV_D_Constant, + SPV_D_Coherent, SPV_D_NonWritable, SPV_D_NonReadable, SPV_D_Uniform, + SPV_D_UniformId, SPV_D_SaturatedConversion, SPV_D_Stream, SPV_D_Location, + SPV_D_Component, SPV_D_Index, SPV_D_Binding, SPV_D_DescriptorSet, SPV_D_Offset, + SPV_D_XfbBuffer, SPV_D_XfbStride, SPV_D_FuncParamAttr, SPV_D_FPRoundingMode, + SPV_D_FPFastMathMode, SPV_D_LinkageAttributes, SPV_D_NoContraction, + SPV_D_InputAttachmentIndex, SPV_D_Alignment, SPV_D_MaxByteOffset, + SPV_D_AlignmentId, SPV_D_MaxByteOffsetId, SPV_D_NoSignedWrap, + SPV_D_NoUnsignedWrap, SPV_D_ExplicitInterpAMD, SPV_D_OverrideCoverageNV, + SPV_D_PassthroughNV, SPV_D_ViewportRelativeNV, + SPV_D_SecondaryViewportRelativeNV, SPV_D_PerPrimitiveNV, SPV_D_PerViewNV, + SPV_D_PerTaskNV, SPV_D_PerVertexNV, SPV_D_NonUniformEXT, + SPV_D_RestrictPointerEXT, SPV_D_AliasedPointerEXT, SPV_D_CounterBuffer, + SPV_D_UserSemantic, SPV_D_UserTypeGOOGLE + ]> { + let returnType = "::mlir::spirv::Decoration"; + let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())"; + let cppNamespace = "::mlir::spirv"; +} + def SPV_D_1D : I32EnumAttrCase<"1D", 0>; def SPV_D_2D : I32EnumAttrCase<"2D", 1>; def SPV_D_3D : I32EnumAttrCase<"3D", 2>; diff --git a/third_party/mlir/include/mlir/Support/StringExtras.h b/third_party/mlir/include/mlir/Support/StringExtras.h new file mode 100644 index 00000000000..a5ec73275ff --- /dev/null +++ b/third_party/mlir/include/mlir/Support/StringExtras.h @@ -0,0 +1,81 @@ +//===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains string utility functions used within MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_STRINGEXTRAS_H +#define MLIR_SUPPORT_STRINGEXTRAS_H + +#include "llvm/ADT/StringExtras.h" + +namespace mlir { +/// Converts a string to snake-case from camel-case by replacing all uppercase +/// letters with '_' followed by the letter in lowercase, except if the +/// uppercase letter is the first character of the string. +inline std::string convertToSnakeCase(llvm::StringRef input) { + std::string snakeCase; + snakeCase.reserve(input.size()); + for (auto c : input) { + if (std::isupper(c)) { + if (!snakeCase.empty() && snakeCase.back() != '_') { + snakeCase.push_back('_'); + } + snakeCase.push_back(llvm::toLower(c)); + } else { + snakeCase.push_back(c); + } + } + return snakeCase; +} + +/// Converts a string from camel-case to snake_case by replacing all occurences +/// of '_' followed by a lowercase letter with the letter in +/// uppercase. Optionally allow capitalization of the first letter (if it is a +/// lowercase letter) +inline std::string convertToCamelCase(llvm::StringRef input, + bool capitalizeFirst = false) { + if (input.empty()) { + return ""; + } + std::string output; + output.reserve(input.size()); + size_t pos = 0; + if (capitalizeFirst && std::islower(input[pos])) { + output.push_back(llvm::toUpper(input[pos])); + pos++; + } + while (pos < input.size()) { + auto cur = input[pos]; + if (cur == '_') { + if (pos && (pos + 1 < input.size())) { + if (std::islower(input[pos + 1])) { + output.push_back(llvm::toUpper(input[pos + 1])); + pos += 2; + continue; + } + } + } + output.push_back(cur); + pos++; + } + return output; +} +} // namespace mlir + +#endif // MLIR_SUPPORT_STRINGEXTRAS_H diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index ae5752a396e..05a1746bccd 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -26,13 +26,12 @@ #include "mlir/IR/Function.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Support/StringExtras.h" using namespace mlir; // TODO(antiagainst): generate these strings using ODS. static constexpr const char kAlignmentAttrName[] = "alignment"; -static constexpr const char kBindingAttrName[] = "binding"; -static constexpr const char kDescriptorSetAttrName[] = "descriptor_set"; static constexpr const char kIndicesAttrName[] = "indices"; static constexpr const char kValueAttrName[] = "value"; static constexpr const char kValuesAttrName[] = "values"; @@ -67,8 +66,7 @@ static LogicalResult extractValueFromConstOp(Operation *op, } template -static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser, - OperationState *state) { +static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser) { Attribute attrVal; SmallVector attr; auto loc = parser->getCurrentLocation(); @@ -89,6 +87,15 @@ static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser, << " attribute specification: " << attrVal; } value = attrOptional.getValue(); + return success(); +} + +template +static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser, + OperationState *state) { + if (parseEnumAttribute(value, parser)) { + return failure(); + } state->addAttribute( spirv::attributeName(), parser->getBuilder().getI32IntegerAttr(bitwiseCast(value))); @@ -601,7 +608,7 @@ static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *state) { spirv::StorageClass storageClass; OpAsmParser::OperandType ptrInfo; Type elementType; - if (parseEnumAttribute(storageClass, parser, state) || + if (parseEnumAttribute(storageClass, parser) || parser->parseOperand(ptrInfo) || parseMemoryAccessAttributes(parser, state) || parser->parseOptionalAttributeDict(state->attributes) || @@ -813,7 +820,7 @@ static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *state) { SmallVector operandInfo; auto loc = parser->getCurrentLocation(); Type elementType; - if (parseEnumAttribute(storageClass, parser, state) || + if (parseEnumAttribute(storageClass, parser) || parser->parseOperandList(operandInfo, 2) || parseMemoryAccessAttributes(parser, state) || parser->parseColon() || parser->parseType(elementType)) { @@ -873,13 +880,17 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) { // Parse optional descriptor binding Attribute set, binding; + auto descriptorSetName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); if (succeeded(parser->parseOptionalKeyword("bind"))) { Type i32Type = parser->getBuilder().getIntegerType(32); if (parser->parseLParen() || - parser->parseAttribute(set, i32Type, kDescriptorSetAttrName, + parser->parseAttribute(set, i32Type, descriptorSetName, state->attributes) || parser->parseComma() || - parser->parseAttribute(binding, i32Type, kBindingAttrName, + parser->parseAttribute(binding, i32Type, bindingName, state->attributes) || parser->parseRParen()) return failure(); @@ -931,12 +942,17 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) { } // Print optional descriptor binding - auto set = varOp.getAttrOfType(kDescriptorSetAttrName); - auto binding = varOp.getAttrOfType(kBindingAttrName); - if (set && binding) { - elidedAttrs.push_back(kDescriptorSetAttrName); - elidedAttrs.push_back(kBindingAttrName); - *printer << " bind(" << set.getInt() << ", " << binding.getInt() << ")"; + auto descriptorSetName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); + auto descriptorSet = varOp.getAttrOfType(descriptorSetName); + auto binding = varOp.getAttrOfType(bindingName); + if (descriptorSet && binding) { + elidedAttrs.push_back(descriptorSetName); + elidedAttrs.push_back(bindingName); + *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() + << ")"; } printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 4c35f6adf91..2ca8f4578e3 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/StringExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" @@ -80,6 +81,9 @@ private: /// Process SPIR-V OpName with `operands` LogicalResult processName(ArrayRef operands); + /// Method to process an OpDecorate instruction. + LogicalResult processDecoration(ArrayRef words); + /// Processes the SPIR-V function at the current `offset` into `binary`. /// The operands to the OpFunction instruction is passed in as ``operands`. /// This method processes each instruction inside the function and dispatches @@ -196,6 +200,9 @@ private: // Result to name mapping. DenseMap nameMap; + // Result to decorations mapping. + DenseMap decorations; + // List of instructions that are processed in a defered fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function @@ -285,6 +292,37 @@ LogicalResult Deserializer::processMemoryModel(ArrayRef operands) { return success(); } +LogicalResult Deserializer::processDecoration(ArrayRef words) { + // TODO : This function should also be auto-generated. For now, since only a + // few decorations are processed/handled in a meaningful manner, going with a + // manual implementation. + if (words.size() < 2) { + return emitError( + unknownLoc, "OpDecorate must have at least result and Decoration"); + } + auto decorationName = + stringifyDecoration(static_cast(words[1])); + if (decorationName.empty()) { + return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; + } + auto attrName = convertToSnakeCase(decorationName); + switch (static_cast(words[1])) { + case spirv::Decoration::DescriptorSet: + case spirv::Decoration::Binding: + if (words.size() != 3) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName << " needs a single integer literal"; + } + decorations[words[0]].set( + opBuilder.getIdentifier(attrName), + opBuilder.getI32IntegerAttr(static_cast(words[2]))); + break; + default: + return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; + } + return success(); +} + LogicalResult Deserializer::processFunction(ArrayRef operands) { // Get the result type if (operands.size() != 4) { @@ -830,6 +868,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processConstantBool(false, operands); case spirv::Opcode::OpConstantNull: return processConstantNull(operands); + case spirv::Opcode::OpDecorate: + return processDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); default: @@ -839,6 +879,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, } namespace { + template <> LogicalResult Deserializer::processOp(ArrayRef words) { diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 7030bd9e71b..35c4088fa0a 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/StringExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" #include "llvm/Support/raw_ostream.h" @@ -127,6 +128,10 @@ private: /// Processes a SPIR-V function op. LogicalResult processFuncOp(FuncOp op); + /// Process attributes that translate to decorations on the result + LogicalResult processDecoration(Location loc, uint32_t resultID, + NamedAttribute attr); + //===--------------------------------------------------------------------===// // Types //===--------------------------------------------------------------------===// @@ -319,6 +324,34 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { return failure(); } +LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, + NamedAttribute attr) { + auto attrName = attr.first.strref(); + auto decorationName = mlir::convertToCamelCase(attrName, true); + auto decoration = spirv::symbolizeDecoration(decorationName); + if (!decoration) { + return emitError( + loc, "non-argument attributes expected to have snake-case-ified " + "decoration name, unhandled attribute with name : ") + << attrName; + } + SmallVector args; + args.push_back(resultID); + args.push_back(static_cast(decoration.getValue())); + switch (decoration.getValue()) { + case spirv::Decoration::DescriptorSet: + case spirv::Decoration::Binding: + if (auto intAttr = attr.second.dyn_cast()) { + args.push_back(intAttr.getValue().getZExtValue()); + break; + } + return emitError(loc, "expected integer attribute for ") << attrName; + default: + return emitError(loc, "unhandled decoration ") << decorationName; + } + return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args); +} + LogicalResult Serializer::processFuncOp(FuncOp op) { uint32_t fnTypeID = 0; // Generate type of the function. diff --git a/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 0c177204712..75da5e7996c 100644 --- a/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -20,9 +20,11 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -39,6 +41,8 @@ using llvm::raw_string_ostream; using llvm::Record; using llvm::RecordKeeper; using llvm::SMLoc; +using llvm::StringRef; +using llvm::Twine; using mlir::tblgen::Attribute; using mlir::tblgen::EnumAttr; using mlir::tblgen::NamedAttribute; @@ -90,7 +94,8 @@ static void emitAttributeSerialization(const Attribute &attr, os << " }\n"; } -static void emitSerializationFunction(const Record *record, const Operator &op, +static void emitSerializationFunction(const Record *attrClass, + const Record *record, const Operator &op, raw_ostream &os) { // If the record has 'autogenSerialization' set to 0, nothing to do if (!record->getValueAsBit("autogenSerialization")) { @@ -101,21 +106,20 @@ static void emitSerializationFunction(const Record *record, const Operator &op, op.getQualCppClassName()) << " {\n"; os << " SmallVector operands;\n"; + os << " SmallVector elidedAttrs;\n"; // Serialize result information if (op.getNumResults() == 1) { - os << " {\n"; - os << " uint32_t typeID = 0;\n"; - os << " if (failed(processType(op.getLoc(), " - "op.getResult()->getType(), typeID))) {\n"; - os << " return failure();\n"; - os << " }\n"; - os << " operands.push_back(typeID);\n"; - /// Create an SSA result for the op - os << " auto resultID = getNextID();\n"; - os << " valueIDMap[op.getResult()] = resultID;\n"; - os << " operands.push_back(resultID);\n"; + os << " uint32_t resultTypeID = 0;\n"; + os << " if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) " + "{\n"; + os << " return failure();\n"; os << " }\n"; + os << " operands.push_back(resultTypeID);\n"; + // Create an SSA result for the op + os << " auto resultID = getNextID();\n"; + os << " valueIDMap[op.getResult()] = resultID;\n"; + os << " operands.push_back(resultID);\n"; } else if (op.getNumResults() != 0) { PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result"); } @@ -140,6 +144,7 @@ static void emitSerializationFunction(const Record *record, const Operator &op, emitAttributeSerialization( (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), record->getLoc(), "op", "operands", attr->name, os); + os << " elidedAttrs.push_back(\"" << attr->name << "\");\n"; } os << " }\n"; } @@ -147,6 +152,20 @@ static void emitSerializationFunction(const Record *record, const Operator &op, os << formatv(" encodeInstructionInto(" "functions, spirv::getOpcode<{0}>(), operands);\n", op.getQualCppClassName()); + + if (op.getNumResults() == 1) { + // All non-argument attributes translated into OpDecorate instruction + os << " for (auto attr : op.getAttrs()) {\n"; + os << " if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return " + "attr.first.is(elided); })) {\n"; + os << " continue;\n"; + os << " }\n"; + os << " if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n"; + os << " return failure();"; + os << " }\n"; + os << " }\n"; + } + os << " return success();\n"; os << "}\n\n"; } @@ -196,7 +215,8 @@ static void emitAttributeDeserialization( } } -static void emitDeserializationFunction(const Record *record, +static void emitDeserializationFunction(const Record *attrClass, + const Record *record, const Operator &op, raw_ostream &os) { // If the record has 'autogenSerialization' set to 0, nothing to do if (!record->getValueAsBit("autogenSerialization")) { @@ -292,8 +312,19 @@ static void emitDeserializationFunction(const Record *record, "operands, attributes); (void)op;\n", op.getQualCppClassName()); if (hasResult) { - os << " valueMap[valueID] = op.getResult();\n"; + os << " valueMap[valueID] = op.getResult();\n\n"; } + + // Import decorations parsed + if (op.getNumResults() == 1) { + os << " if (decorations.count(valueID)) {\n"; + os << " auto decorationAttrs = decorations[valueID];\n"; + os << " for (auto attr : decorationAttrs.getAttrs()) {\n"; + os << " op.setAttr(attr.first, attr.second);\n"; + os << " }\n"; + os << " }\n"; + } + os << " return success();\n"; os << "}\n\n"; } @@ -330,6 +361,7 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper, utilsString; raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString), serFn(serFnString), deserFn(deserFnString), utils(utilsString); + auto attrClass = recordKeeper.getClass("Attr"); declareOpcodeFn(utils); initDispatchSerializationFn(dSerFn); @@ -341,9 +373,9 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper, } Operator op(def); emitGetOpcodeFunction(def, op, utils); - emitSerializationFunction(def, op, serFn); + emitSerializationFunction(attrClass, def, op, serFn); emitSerializationDispatch(op, dSerFn); - emitDeserializationFunction(def, op, deserFn); + emitDeserializationFunction(attrClass, def, op, deserFn); emitDeserializationDispatch(op, def, dDesFn); } finalizeDispatchSerializationFn(dSerFn); @@ -378,21 +410,6 @@ static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) { "SymbolizeFnTy symbolizeEnum();\n"; } -std::string convertSnakeCase(llvm::StringRef inputString) { - std::string snakeCase; - for (auto c : inputString) { - if (c >= 'A' && c <= 'Z') { - if (!snakeCase.empty()) { - snakeCase.push_back('_'); - } - snakeCase.push_back((c - 'A') + 'a'); - } else { - snakeCase.push_back(c); - } - } - return snakeCase; -} - static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, raw_ostream &os) { auto enumName = enumAttr.getEnumClassName(); @@ -400,7 +417,7 @@ static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, << " {\n"; os << " " << formatv("static constexpr const char attrName[] = \"{0}\";\n", - convertSnakeCase(enumName)); + mlir::convertToSnakeCase(enumName)); os << " return attrName;\n"; os << "}\n"; } diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py index de177566a67..ac00179ec7a 100755 --- a/third_party/mlir/utils/spirv/gen_spirv_dialect.py +++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py @@ -109,6 +109,28 @@ def split_list_into_sublists(items, offset): return chuncks +def uniquify(lst, equality_fn): + """Returns a list after pruning duplicate elements. + + Arguments: + - lst: List whose elements are to be uniqued. + - equality_fn: Function used to compare equality between elements of the + list. + + Returns: + - A list with all duplicated removed. The order of elements is same as the + original list, with only the first occurence of duplicates retained. + """ + keys = set() + unique_lst = [] + for elem in lst: + key = equality_fn(elem) + if equality_fn(key) not in keys: + unique_lst.append(elem) + keys.add(key) + return unique_lst + + def gen_operand_kind_enum_attr(operand_kind): """Generates the TableGen I32EnumAttr definition for the given operand kind. @@ -123,6 +145,7 @@ def gen_operand_kind_enum_attr(operand_kind): kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z']) kind_cases = [(case['enumerant'], case['value']) for case in operand_kind['enumerants']] + kind_cases = uniquify(kind_cases, lambda x: x[1]) max_len = max([len(symbol) for (symbol, _) in kind_cases]) # Generate the definition for each enum case