From 21e4e8b991fd2d408b1a745954d4d3e895396fd7 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 5 Nov 2019 17:58:16 -0800 Subject: [PATCH] Add (parse|print)OptionalAttrDictWithKeyword hooks to simplify parsing attribute dictionaries with regions. Many operations with regions add an additional 'attributes' prefix when printing the attribute dictionary to differentiate it from the region body. This leads to duplicated logic for detecting when to actually print the attribute dictionary. PiperOrigin-RevId: 278747681 Change-Id: I14cd3c5c297a98e756f84a3fb4a24d440deebd0b --- .../mlir/include/mlir/IR/OpImplementation.h | 11 ++++++ .../mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 20 ++--------- third_party/mlir/lib/IR/AsmPrinter.cpp | 34 +++++++++++-------- third_party/mlir/lib/IR/FunctionSupport.cpp | 5 ++- third_party/mlir/lib/IR/Module.cpp | 15 +++----- third_party/mlir/lib/Parser/Parser.cpp | 11 +++++- 6 files changed, 51 insertions(+), 45 deletions(-) diff --git a/third_party/mlir/include/mlir/IR/OpImplementation.h b/third_party/mlir/include/mlir/IR/OpImplementation.h index 6d41f8ffc1b..a167a7df14c 100644 --- a/third_party/mlir/include/mlir/IR/OpImplementation.h +++ b/third_party/mlir/include/mlir/IR/OpImplementation.h @@ -78,6 +78,12 @@ public: virtual void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) = 0; + /// If the specified operation has attributes, print out an attribute + /// dictionary prefixed with 'attributes'. + virtual void + printOptionalAttrDictWithKeyword(ArrayRef attrs, + ArrayRef elidedAttrs = {}) = 0; + /// Print the entire operation with the default generic assembly form. virtual void printGenericOp(Operation *op) = 0; @@ -342,6 +348,11 @@ public: virtual ParseResult parseOptionalAttrDict(SmallVectorImpl &result) = 0; + /// Parse a named dictionary into 'result' if the `attributes` keyword is + /// present. + virtual ParseResult + parseOptionalAttrDictWithKeyword(SmallVectorImpl &result) = 0; + //===--------------------------------------------------------------------===// // Identifier Parsing //===--------------------------------------------------------------------===// diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index e8176191a8d..778e7cb3b24 100644 --- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1628,10 +1628,8 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) return failure(); - if (succeeded(parser.parseOptionalKeyword("attributes"))) { - if (parser.parseOptionalAttrDict(state.attributes)) - return failure(); - } + if (parser.parseOptionalAttrDictWithKeyword(state.attributes)) + return failure(); spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location); return success(); @@ -1657,19 +1655,7 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) { printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); - - bool printAttrDict = - elidedAttrs.size() != 2 || - llvm::any_of(op->getAttrs(), [&addressingModelAttrName, - &memoryModelAttrName](NamedAttribute attr) { - return attr.first != addressingModelAttrName && - attr.first != memoryModelAttrName; - }); - - if (printAttrDict) { - printer << " attributes"; - printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); - } + printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs); } static LogicalResult verify(spirv::ModuleOp moduleOp) { diff --git a/third_party/mlir/lib/IR/AsmPrinter.cpp b/third_party/mlir/lib/IR/AsmPrinter.cpp index 0e6b7882e14..af958c8e61f 100644 --- a/third_party/mlir/lib/IR/AsmPrinter.cpp +++ b/third_party/mlir/lib/IR/AsmPrinter.cpp @@ -421,7 +421,8 @@ public: protected: void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}); + ArrayRef elidedAttrs = {}, + bool withKeyword = false); void printTrailingLocation(Location loc); void printLocationInternal(LocationAttr loc, bool pretty = false); void printDenseElementsAttr(DenseElementsAttr attr); @@ -1327,27 +1328,26 @@ void ModulePrinter::printIntegerSet(IntegerSet set) { //===----------------------------------------------------------------------===// void ModulePrinter::printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs) { + ArrayRef elidedAttrs, + bool withKeyword) { // If there are no attributes, then there is nothing to be done. if (attrs.empty()) return; // Filter out any attributes that shouldn't be included. - SmallVector filteredAttrs; - for (auto attr : attrs) { - // If the caller has requested that this attribute be ignored, then drop it. - if (llvm::any_of(elidedAttrs, - [&](StringRef elided) { return attr.first.is(elided); })) - continue; - - // Otherwise add it to our filteredAttrs list. - filteredAttrs.push_back(attr); - } + SmallVector filteredAttrs( + llvm::make_filter_range(attrs, [&](NamedAttribute attr) { + return !llvm::is_contained(elidedAttrs, attr.first.strref()); + })); // If there are no attributes left to print after filtering, then we're done. if (filteredAttrs.empty()) return; + // Print the 'attributes' keyword if necessary. + if (withKeyword) + os << " attributes "; + // Otherwise, print them all out in braces. os << " {"; interleaveComma(filteredAttrs, [&](NamedAttribute attr) { @@ -1389,8 +1389,14 @@ public: void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) override { - return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); - }; + ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); + } + void printOptionalAttrDictWithKeyword( + ArrayRef attrs, + ArrayRef elidedAttrs = {}) override { + ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, + /*withKeyword=*/true); + } enum { nameSentinel = ~0U }; diff --git a/third_party/mlir/lib/IR/FunctionSupport.cpp b/third_party/mlir/lib/IR/FunctionSupport.cpp index aa9965e89d1..d1ba2d30fa1 100644 --- a/third_party/mlir/lib/IR/FunctionSupport.cpp +++ b/third_party/mlir/lib/IR/FunctionSupport.cpp @@ -183,9 +183,8 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, << (errorMessage.empty() ? "" : ": ") << errorMessage; // If function attributes are present, parse them. - if (succeeded(parser.parseOptionalKeyword("attributes"))) - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); // Add the attributes to the function arguments. SmallString<8> attrNameBuf; diff --git a/third_party/mlir/lib/IR/Module.cpp b/third_party/mlir/lib/IR/Module.cpp index 3a08ee3bf27..f5cc98e39cf 100644 --- a/third_party/mlir/lib/IR/Module.cpp +++ b/third_party/mlir/lib/IR/Module.cpp @@ -48,9 +48,8 @@ ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) { result.attributes); // If module attributes are present, parse them. - if (succeeded(parser.parseOptionalKeyword("attributes"))) - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); // Parse the module body. auto *body = result.addRegion(); @@ -65,18 +64,14 @@ ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) { void ModuleOp::print(OpAsmPrinter &p) { p << "module"; - Optional name = getName(); - if (name) { + if (Optional name = getName()) { p << ' '; p.printSymbolName(*name); } // Print the module attributes. - auto attrs = getAttrs(); - if (!attrs.empty() && !(attrs.size() == 1 && name)) { - p << " attributes"; - p.printOptionalAttrDict(attrs, {mlir::SymbolTable::getSymbolAttrName()}); - } + p.printOptionalAttrDictWithKeyword(getAttrs(), + {mlir::SymbolTable::getSymbolAttrName()}); // Print the region. p.printRegion(getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, diff --git a/third_party/mlir/lib/Parser/Parser.cpp b/third_party/mlir/lib/Parser/Parser.cpp index 3a45933db87..35c694b6a43 100644 --- a/third_party/mlir/lib/Parser/Parser.cpp +++ b/third_party/mlir/lib/Parser/Parser.cpp @@ -1533,7 +1533,7 @@ Attribute Parser::parseAttribute(Type type) { /// ParseResult Parser::parseAttributeDict(SmallVectorImpl &attributes) { - if (!consumeIf(Token::l_brace)) + if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) return failure(); auto parseElt = [&]() -> ParseResult { @@ -3874,6 +3874,15 @@ public: return parser.parseAttributeDict(result); } + /// Parse a named dictionary into 'result' if the `attributes` keyword is + /// present. + ParseResult parseOptionalAttrDictWithKeyword( + SmallVectorImpl &result) override { + if (failed(parseOptionalKeyword("attributes"))) + return success(); + return parser.parseAttributeDict(result); + } + //===--------------------------------------------------------------------===// // Identifier Parsing //===--------------------------------------------------------------------===//