From e207706514c83cd6d7907f557aa1745e18d351a3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2019 14:14:28 -0700 Subject: [PATCH] Add support for (de)serialization of SPIR-V Op Decorations All non-argument attributes specified for an operation are treated as decorations on the result value and (de)serialized using OpDecorate instruction. An error is generated if an attribute is not an argument, and the name doesn't correspond to a Decoration enum. Name of the attributes that represent decoerations are to be the snake-case-ified version of the Decoration name. Add utility methods to convert to snake-case and camel-case. PiperOrigin-RevId: 260792638 --- third_party/mlir/BUILD | 1 + .../include/mlir/Dialect/SPIRV/SPIRVBase.td | 91 +++++++++++++++++++ .../mlir/include/mlir/Support/StringExtras.h | 81 +++++++++++++++++ .../mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 44 ++++++--- .../SPIRV/Serialization/Deserializer.cpp | 41 +++++++++ .../SPIRV/Serialization/Serializer.cpp | 33 +++++++ .../mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 81 ++++++++++------- .../mlir/utils/spirv/gen_spirv_dialect.py | 23 +++++ 8 files changed, 349 insertions(+), 46 deletions(-) create mode 100644 third_party/mlir/include/mlir/Support/StringExtras.h diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index d2c141bf231..7d47aae2270 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -409,6 +409,7 @@ cc_library( "include/mlir/Support/MathExtras.h", "include/mlir/Support/STLExtras.h", "include/mlir/Support/StorageUniquer.h", + "include/mlir/Support/StringExtras.h", ], includes = ["include"], deps = [ diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index a12f3390125..a820c11dbdb 100644 --- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -196,6 +196,97 @@ def SPV_AddressingModelAttr : let cppNamespace = "::mlir::spirv"; } +def SPV_D_RelaxedPrecision : I32EnumAttrCase<"RelaxedPrecision", 0>; +def SPV_D_SpecId : I32EnumAttrCase<"SpecId", 1>; +def SPV_D_Block : I32EnumAttrCase<"Block", 2>; +def SPV_D_BufferBlock : I32EnumAttrCase<"BufferBlock", 3>; +def SPV_D_RowMajor : I32EnumAttrCase<"RowMajor", 4>; +def SPV_D_ColMajor : I32EnumAttrCase<"ColMajor", 5>; +def SPV_D_ArrayStride : I32EnumAttrCase<"ArrayStride", 6>; +def SPV_D_MatrixStride : I32EnumAttrCase<"MatrixStride", 7>; +def SPV_D_GLSLShared : I32EnumAttrCase<"GLSLShared", 8>; +def SPV_D_GLSLPacked : I32EnumAttrCase<"GLSLPacked", 9>; +def SPV_D_CPacked : I32EnumAttrCase<"CPacked", 10>; +def SPV_D_BuiltIn : I32EnumAttrCase<"BuiltIn", 11>; +def SPV_D_NoPerspective : I32EnumAttrCase<"NoPerspective", 13>; +def SPV_D_Flat : I32EnumAttrCase<"Flat", 14>; +def SPV_D_Patch : I32EnumAttrCase<"Patch", 15>; +def SPV_D_Centroid : I32EnumAttrCase<"Centroid", 16>; +def SPV_D_Sample : I32EnumAttrCase<"Sample", 17>; +def SPV_D_Invariant : I32EnumAttrCase<"Invariant", 18>; +def SPV_D_Restrict : I32EnumAttrCase<"Restrict", 19>; +def SPV_D_Aliased : I32EnumAttrCase<"Aliased", 20>; +def SPV_D_Volatile : I32EnumAttrCase<"Volatile", 21>; +def SPV_D_Constant : I32EnumAttrCase<"Constant", 22>; +def SPV_D_Coherent : I32EnumAttrCase<"Coherent", 23>; +def SPV_D_NonWritable : I32EnumAttrCase<"NonWritable", 24>; +def SPV_D_NonReadable : I32EnumAttrCase<"NonReadable", 25>; +def SPV_D_Uniform : I32EnumAttrCase<"Uniform", 26>; +def SPV_D_UniformId : I32EnumAttrCase<"UniformId", 27>; +def SPV_D_SaturatedConversion : I32EnumAttrCase<"SaturatedConversion", 28>; +def SPV_D_Stream : I32EnumAttrCase<"Stream", 29>; +def SPV_D_Location : I32EnumAttrCase<"Location", 30>; +def SPV_D_Component : I32EnumAttrCase<"Component", 31>; +def SPV_D_Index : I32EnumAttrCase<"Index", 32>; +def SPV_D_Binding : I32EnumAttrCase<"Binding", 33>; +def SPV_D_DescriptorSet : I32EnumAttrCase<"DescriptorSet", 34>; +def SPV_D_Offset : I32EnumAttrCase<"Offset", 35>; +def SPV_D_XfbBuffer : I32EnumAttrCase<"XfbBuffer", 36>; +def SPV_D_XfbStride : I32EnumAttrCase<"XfbStride", 37>; +def SPV_D_FuncParamAttr : I32EnumAttrCase<"FuncParamAttr", 38>; +def SPV_D_FPRoundingMode : I32EnumAttrCase<"FPRoundingMode", 39>; +def SPV_D_FPFastMathMode : I32EnumAttrCase<"FPFastMathMode", 40>; +def SPV_D_LinkageAttributes : I32EnumAttrCase<"LinkageAttributes", 41>; +def SPV_D_NoContraction : I32EnumAttrCase<"NoContraction", 42>; +def SPV_D_InputAttachmentIndex : I32EnumAttrCase<"InputAttachmentIndex", 43>; +def SPV_D_Alignment : I32EnumAttrCase<"Alignment", 44>; +def SPV_D_MaxByteOffset : I32EnumAttrCase<"MaxByteOffset", 45>; +def SPV_D_AlignmentId : I32EnumAttrCase<"AlignmentId", 46>; +def SPV_D_MaxByteOffsetId : I32EnumAttrCase<"MaxByteOffsetId", 47>; +def SPV_D_NoSignedWrap : I32EnumAttrCase<"NoSignedWrap", 4469>; +def SPV_D_NoUnsignedWrap : I32EnumAttrCase<"NoUnsignedWrap", 4470>; +def SPV_D_ExplicitInterpAMD : I32EnumAttrCase<"ExplicitInterpAMD", 4999>; +def SPV_D_OverrideCoverageNV : I32EnumAttrCase<"OverrideCoverageNV", 5248>; +def SPV_D_PassthroughNV : I32EnumAttrCase<"PassthroughNV", 5250>; +def SPV_D_ViewportRelativeNV : I32EnumAttrCase<"ViewportRelativeNV", 5252>; +def SPV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewportRelativeNV", 5256>; +def SPV_D_PerPrimitiveNV : I32EnumAttrCase<"PerPrimitiveNV", 5271>; +def SPV_D_PerViewNV : I32EnumAttrCase<"PerViewNV", 5272>; +def SPV_D_PerTaskNV : I32EnumAttrCase<"PerTaskNV", 5273>; +def SPV_D_PerVertexNV : I32EnumAttrCase<"PerVertexNV", 5285>; +def SPV_D_NonUniformEXT : I32EnumAttrCase<"NonUniformEXT", 5300>; +def SPV_D_RestrictPointerEXT : I32EnumAttrCase<"RestrictPointerEXT", 5355>; +def SPV_D_AliasedPointerEXT : I32EnumAttrCase<"AliasedPointerEXT", 5356>; +def SPV_D_CounterBuffer : I32EnumAttrCase<"CounterBuffer", 5634>; +def SPV_D_UserSemantic : I32EnumAttrCase<"UserSemantic", 5635>; +def SPV_D_UserTypeGOOGLE : I32EnumAttrCase<"UserTypeGOOGLE", 5636>; + +def SPV_DecorationAttr : + I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [ + SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock, + SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride, + SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn, + SPV_D_NoPerspective, SPV_D_Flat, SPV_D_Patch, SPV_D_Centroid, SPV_D_Sample, + SPV_D_Invariant, SPV_D_Restrict, SPV_D_Aliased, SPV_D_Volatile, SPV_D_Constant, + SPV_D_Coherent, SPV_D_NonWritable, SPV_D_NonReadable, SPV_D_Uniform, + SPV_D_UniformId, SPV_D_SaturatedConversion, SPV_D_Stream, SPV_D_Location, + SPV_D_Component, SPV_D_Index, SPV_D_Binding, SPV_D_DescriptorSet, SPV_D_Offset, + SPV_D_XfbBuffer, SPV_D_XfbStride, SPV_D_FuncParamAttr, SPV_D_FPRoundingMode, + SPV_D_FPFastMathMode, SPV_D_LinkageAttributes, SPV_D_NoContraction, + SPV_D_InputAttachmentIndex, SPV_D_Alignment, SPV_D_MaxByteOffset, + SPV_D_AlignmentId, SPV_D_MaxByteOffsetId, SPV_D_NoSignedWrap, + SPV_D_NoUnsignedWrap, SPV_D_ExplicitInterpAMD, SPV_D_OverrideCoverageNV, + SPV_D_PassthroughNV, SPV_D_ViewportRelativeNV, + SPV_D_SecondaryViewportRelativeNV, SPV_D_PerPrimitiveNV, SPV_D_PerViewNV, + SPV_D_PerTaskNV, SPV_D_PerVertexNV, SPV_D_NonUniformEXT, + SPV_D_RestrictPointerEXT, SPV_D_AliasedPointerEXT, SPV_D_CounterBuffer, + SPV_D_UserSemantic, SPV_D_UserTypeGOOGLE + ]> { + let returnType = "::mlir::spirv::Decoration"; + let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())"; + let cppNamespace = "::mlir::spirv"; +} + def SPV_D_1D : I32EnumAttrCase<"1D", 0>; def SPV_D_2D : I32EnumAttrCase<"2D", 1>; def SPV_D_3D : I32EnumAttrCase<"3D", 2>; diff --git a/third_party/mlir/include/mlir/Support/StringExtras.h b/third_party/mlir/include/mlir/Support/StringExtras.h new file mode 100644 index 00000000000..a5ec73275ff --- /dev/null +++ b/third_party/mlir/include/mlir/Support/StringExtras.h @@ -0,0 +1,81 @@ +//===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains string utility functions used within MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_STRINGEXTRAS_H +#define MLIR_SUPPORT_STRINGEXTRAS_H + +#include "llvm/ADT/StringExtras.h" + +namespace mlir { +/// Converts a string to snake-case from camel-case by replacing all uppercase +/// letters with '_' followed by the letter in lowercase, except if the +/// uppercase letter is the first character of the string. +inline std::string convertToSnakeCase(llvm::StringRef input) { + std::string snakeCase; + snakeCase.reserve(input.size()); + for (auto c : input) { + if (std::isupper(c)) { + if (!snakeCase.empty() && snakeCase.back() != '_') { + snakeCase.push_back('_'); + } + snakeCase.push_back(llvm::toLower(c)); + } else { + snakeCase.push_back(c); + } + } + return snakeCase; +} + +/// Converts a string from camel-case to snake_case by replacing all occurences +/// of '_' followed by a lowercase letter with the letter in +/// uppercase. Optionally allow capitalization of the first letter (if it is a +/// lowercase letter) +inline std::string convertToCamelCase(llvm::StringRef input, + bool capitalizeFirst = false) { + if (input.empty()) { + return ""; + } + std::string output; + output.reserve(input.size()); + size_t pos = 0; + if (capitalizeFirst && std::islower(input[pos])) { + output.push_back(llvm::toUpper(input[pos])); + pos++; + } + while (pos < input.size()) { + auto cur = input[pos]; + if (cur == '_') { + if (pos && (pos + 1 < input.size())) { + if (std::islower(input[pos + 1])) { + output.push_back(llvm::toUpper(input[pos + 1])); + pos += 2; + continue; + } + } + } + output.push_back(cur); + pos++; + } + return output; +} +} // namespace mlir + +#endif // MLIR_SUPPORT_STRINGEXTRAS_H diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index ae5752a396e..05a1746bccd 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -26,13 +26,12 @@ #include "mlir/IR/Function.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Support/StringExtras.h" using namespace mlir; // TODO(antiagainst): generate these strings using ODS. static constexpr const char kAlignmentAttrName[] = "alignment"; -static constexpr const char kBindingAttrName[] = "binding"; -static constexpr const char kDescriptorSetAttrName[] = "descriptor_set"; static constexpr const char kIndicesAttrName[] = "indices"; static constexpr const char kValueAttrName[] = "value"; static constexpr const char kValuesAttrName[] = "values"; @@ -67,8 +66,7 @@ static LogicalResult extractValueFromConstOp(Operation *op, } template -static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser, - OperationState *state) { +static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser) { Attribute attrVal; SmallVector attr; auto loc = parser->getCurrentLocation(); @@ -89,6 +87,15 @@ static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser, << " attribute specification: " << attrVal; } value = attrOptional.getValue(); + return success(); +} + +template +static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser, + OperationState *state) { + if (parseEnumAttribute(value, parser)) { + return failure(); + } state->addAttribute( spirv::attributeName(), parser->getBuilder().getI32IntegerAttr(bitwiseCast(value))); @@ -601,7 +608,7 @@ static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *state) { spirv::StorageClass storageClass; OpAsmParser::OperandType ptrInfo; Type elementType; - if (parseEnumAttribute(storageClass, parser, state) || + if (parseEnumAttribute(storageClass, parser) || parser->parseOperand(ptrInfo) || parseMemoryAccessAttributes(parser, state) || parser->parseOptionalAttributeDict(state->attributes) || @@ -813,7 +820,7 @@ static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *state) { SmallVector operandInfo; auto loc = parser->getCurrentLocation(); Type elementType; - if (parseEnumAttribute(storageClass, parser, state) || + if (parseEnumAttribute(storageClass, parser) || parser->parseOperandList(operandInfo, 2) || parseMemoryAccessAttributes(parser, state) || parser->parseColon() || parser->parseType(elementType)) { @@ -873,13 +880,17 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) { // Parse optional descriptor binding Attribute set, binding; + auto descriptorSetName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); if (succeeded(parser->parseOptionalKeyword("bind"))) { Type i32Type = parser->getBuilder().getIntegerType(32); if (parser->parseLParen() || - parser->parseAttribute(set, i32Type, kDescriptorSetAttrName, + parser->parseAttribute(set, i32Type, descriptorSetName, state->attributes) || parser->parseComma() || - parser->parseAttribute(binding, i32Type, kBindingAttrName, + parser->parseAttribute(binding, i32Type, bindingName, state->attributes) || parser->parseRParen()) return failure(); @@ -931,12 +942,17 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) { } // Print optional descriptor binding - auto set = varOp.getAttrOfType(kDescriptorSetAttrName); - auto binding = varOp.getAttrOfType(kBindingAttrName); - if (set && binding) { - elidedAttrs.push_back(kDescriptorSetAttrName); - elidedAttrs.push_back(kBindingAttrName); - *printer << " bind(" << set.getInt() << ", " << binding.getInt() << ")"; + auto descriptorSetName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); + auto descriptorSet = varOp.getAttrOfType(descriptorSetName); + auto binding = varOp.getAttrOfType(bindingName); + if (descriptorSet && binding) { + elidedAttrs.push_back(descriptorSetName); + elidedAttrs.push_back(bindingName); + *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() + << ")"; } printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 4c35f6adf91..2ca8f4578e3 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/StringExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" @@ -80,6 +81,9 @@ private: /// Process SPIR-V OpName with `operands` LogicalResult processName(ArrayRef operands); + /// Method to process an OpDecorate instruction. + LogicalResult processDecoration(ArrayRef words); + /// Processes the SPIR-V function at the current `offset` into `binary`. /// The operands to the OpFunction instruction is passed in as ``operands`. /// This method processes each instruction inside the function and dispatches @@ -196,6 +200,9 @@ private: // Result to name mapping. DenseMap nameMap; + // Result to decorations mapping. + DenseMap decorations; + // List of instructions that are processed in a defered fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function @@ -285,6 +292,37 @@ LogicalResult Deserializer::processMemoryModel(ArrayRef operands) { return success(); } +LogicalResult Deserializer::processDecoration(ArrayRef words) { + // TODO : This function should also be auto-generated. For now, since only a + // few decorations are processed/handled in a meaningful manner, going with a + // manual implementation. + if (words.size() < 2) { + return emitError( + unknownLoc, "OpDecorate must have at least result and Decoration"); + } + auto decorationName = + stringifyDecoration(static_cast(words[1])); + if (decorationName.empty()) { + return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; + } + auto attrName = convertToSnakeCase(decorationName); + switch (static_cast(words[1])) { + case spirv::Decoration::DescriptorSet: + case spirv::Decoration::Binding: + if (words.size() != 3) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName << " needs a single integer literal"; + } + decorations[words[0]].set( + opBuilder.getIdentifier(attrName), + opBuilder.getI32IntegerAttr(static_cast(words[2]))); + break; + default: + return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; + } + return success(); +} + LogicalResult Deserializer::processFunction(ArrayRef operands) { // Get the result type if (operands.size() != 4) { @@ -830,6 +868,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processConstantBool(false, operands); case spirv::Opcode::OpConstantNull: return processConstantNull(operands); + case spirv::Opcode::OpDecorate: + return processDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); default: @@ -839,6 +879,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, } namespace { + template <> LogicalResult Deserializer::processOp(ArrayRef words) { diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 7030bd9e71b..35c4088fa0a 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/StringExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" #include "llvm/Support/raw_ostream.h" @@ -127,6 +128,10 @@ private: /// Processes a SPIR-V function op. LogicalResult processFuncOp(FuncOp op); + /// Process attributes that translate to decorations on the result + LogicalResult processDecoration(Location loc, uint32_t resultID, + NamedAttribute attr); + //===--------------------------------------------------------------------===// // Types //===--------------------------------------------------------------------===// @@ -319,6 +324,34 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { return failure(); } +LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, + NamedAttribute attr) { + auto attrName = attr.first.strref(); + auto decorationName = mlir::convertToCamelCase(attrName, true); + auto decoration = spirv::symbolizeDecoration(decorationName); + if (!decoration) { + return emitError( + loc, "non-argument attributes expected to have snake-case-ified " + "decoration name, unhandled attribute with name : ") + << attrName; + } + SmallVector args; + args.push_back(resultID); + args.push_back(static_cast(decoration.getValue())); + switch (decoration.getValue()) { + case spirv::Decoration::DescriptorSet: + case spirv::Decoration::Binding: + if (auto intAttr = attr.second.dyn_cast()) { + args.push_back(intAttr.getValue().getZExtValue()); + break; + } + return emitError(loc, "expected integer attribute for ") << attrName; + default: + return emitError(loc, "unhandled decoration ") << decorationName; + } + return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args); +} + LogicalResult Serializer::processFuncOp(FuncOp op) { uint32_t fnTypeID = 0; // Generate type of the function. diff --git a/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 0c177204712..75da5e7996c 100644 --- a/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -20,9 +20,11 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -39,6 +41,8 @@ using llvm::raw_string_ostream; using llvm::Record; using llvm::RecordKeeper; using llvm::SMLoc; +using llvm::StringRef; +using llvm::Twine; using mlir::tblgen::Attribute; using mlir::tblgen::EnumAttr; using mlir::tblgen::NamedAttribute; @@ -90,7 +94,8 @@ static void emitAttributeSerialization(const Attribute &attr, os << " }\n"; } -static void emitSerializationFunction(const Record *record, const Operator &op, +static void emitSerializationFunction(const Record *attrClass, + const Record *record, const Operator &op, raw_ostream &os) { // If the record has 'autogenSerialization' set to 0, nothing to do if (!record->getValueAsBit("autogenSerialization")) { @@ -101,21 +106,20 @@ static void emitSerializationFunction(const Record *record, const Operator &op, op.getQualCppClassName()) << " {\n"; os << " SmallVector operands;\n"; + os << " SmallVector elidedAttrs;\n"; // Serialize result information if (op.getNumResults() == 1) { - os << " {\n"; - os << " uint32_t typeID = 0;\n"; - os << " if (failed(processType(op.getLoc(), " - "op.getResult()->getType(), typeID))) {\n"; - os << " return failure();\n"; - os << " }\n"; - os << " operands.push_back(typeID);\n"; - /// Create an SSA result for the op - os << " auto resultID = getNextID();\n"; - os << " valueIDMap[op.getResult()] = resultID;\n"; - os << " operands.push_back(resultID);\n"; + os << " uint32_t resultTypeID = 0;\n"; + os << " if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) " + "{\n"; + os << " return failure();\n"; os << " }\n"; + os << " operands.push_back(resultTypeID);\n"; + // Create an SSA result for the op + os << " auto resultID = getNextID();\n"; + os << " valueIDMap[op.getResult()] = resultID;\n"; + os << " operands.push_back(resultID);\n"; } else if (op.getNumResults() != 0) { PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result"); } @@ -140,6 +144,7 @@ static void emitSerializationFunction(const Record *record, const Operator &op, emitAttributeSerialization( (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), record->getLoc(), "op", "operands", attr->name, os); + os << " elidedAttrs.push_back(\"" << attr->name << "\");\n"; } os << " }\n"; } @@ -147,6 +152,20 @@ static void emitSerializationFunction(const Record *record, const Operator &op, os << formatv(" encodeInstructionInto(" "functions, spirv::getOpcode<{0}>(), operands);\n", op.getQualCppClassName()); + + if (op.getNumResults() == 1) { + // All non-argument attributes translated into OpDecorate instruction + os << " for (auto attr : op.getAttrs()) {\n"; + os << " if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return " + "attr.first.is(elided); })) {\n"; + os << " continue;\n"; + os << " }\n"; + os << " if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n"; + os << " return failure();"; + os << " }\n"; + os << " }\n"; + } + os << " return success();\n"; os << "}\n\n"; } @@ -196,7 +215,8 @@ static void emitAttributeDeserialization( } } -static void emitDeserializationFunction(const Record *record, +static void emitDeserializationFunction(const Record *attrClass, + const Record *record, const Operator &op, raw_ostream &os) { // If the record has 'autogenSerialization' set to 0, nothing to do if (!record->getValueAsBit("autogenSerialization")) { @@ -292,8 +312,19 @@ static void emitDeserializationFunction(const Record *record, "operands, attributes); (void)op;\n", op.getQualCppClassName()); if (hasResult) { - os << " valueMap[valueID] = op.getResult();\n"; + os << " valueMap[valueID] = op.getResult();\n\n"; } + + // Import decorations parsed + if (op.getNumResults() == 1) { + os << " if (decorations.count(valueID)) {\n"; + os << " auto decorationAttrs = decorations[valueID];\n"; + os << " for (auto attr : decorationAttrs.getAttrs()) {\n"; + os << " op.setAttr(attr.first, attr.second);\n"; + os << " }\n"; + os << " }\n"; + } + os << " return success();\n"; os << "}\n\n"; } @@ -330,6 +361,7 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper, utilsString; raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString), serFn(serFnString), deserFn(deserFnString), utils(utilsString); + auto attrClass = recordKeeper.getClass("Attr"); declareOpcodeFn(utils); initDispatchSerializationFn(dSerFn); @@ -341,9 +373,9 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper, } Operator op(def); emitGetOpcodeFunction(def, op, utils); - emitSerializationFunction(def, op, serFn); + emitSerializationFunction(attrClass, def, op, serFn); emitSerializationDispatch(op, dSerFn); - emitDeserializationFunction(def, op, deserFn); + emitDeserializationFunction(attrClass, def, op, deserFn); emitDeserializationDispatch(op, def, dDesFn); } finalizeDispatchSerializationFn(dSerFn); @@ -378,21 +410,6 @@ static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) { "SymbolizeFnTy symbolizeEnum();\n"; } -std::string convertSnakeCase(llvm::StringRef inputString) { - std::string snakeCase; - for (auto c : inputString) { - if (c >= 'A' && c <= 'Z') { - if (!snakeCase.empty()) { - snakeCase.push_back('_'); - } - snakeCase.push_back((c - 'A') + 'a'); - } else { - snakeCase.push_back(c); - } - } - return snakeCase; -} - static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, raw_ostream &os) { auto enumName = enumAttr.getEnumClassName(); @@ -400,7 +417,7 @@ static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, << " {\n"; os << " " << formatv("static constexpr const char attrName[] = \"{0}\";\n", - convertSnakeCase(enumName)); + mlir::convertToSnakeCase(enumName)); os << " return attrName;\n"; os << "}\n"; } diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py index de177566a67..ac00179ec7a 100755 --- a/third_party/mlir/utils/spirv/gen_spirv_dialect.py +++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py @@ -109,6 +109,28 @@ def split_list_into_sublists(items, offset): return chuncks +def uniquify(lst, equality_fn): + """Returns a list after pruning duplicate elements. + + Arguments: + - lst: List whose elements are to be uniqued. + - equality_fn: Function used to compare equality between elements of the + list. + + Returns: + - A list with all duplicated removed. The order of elements is same as the + original list, with only the first occurence of duplicates retained. + """ + keys = set() + unique_lst = [] + for elem in lst: + key = equality_fn(elem) + if equality_fn(key) not in keys: + unique_lst.append(elem) + keys.add(key) + return unique_lst + + def gen_operand_kind_enum_attr(operand_kind): """Generates the TableGen I32EnumAttr definition for the given operand kind. @@ -123,6 +145,7 @@ def gen_operand_kind_enum_attr(operand_kind): kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z']) kind_cases = [(case['enumerant'], case['value']) for case in operand_kind['enumerants']] + kind_cases = uniquify(kind_cases, lambda x: x[1]) max_len = max([len(symbol) for (symbol, _) in kind_cases]) # Generate the definition for each enum case