NFC: Add support for parsing attributes programmatically via mlir::parseAttribute.

This matches the behavior of the public mlir::parseType, and even uses the internal implementation.

PiperOrigin-RevId: 275989777
Change-Id: If3c060a39e195ad565bd831327c702b0c02be5a7
This commit is contained in:
River Riddle 2019-10-21 21:34:21 -07:00 committed by TensorFlower Gardener
parent 5f0ed20652
commit a9c1ef68bf
2 changed files with 60 additions and 9 deletions

View File

@ -31,6 +31,7 @@ class StringRef;
} // end namespace llvm
namespace mlir {
class Attribute;
class Location;
class MLIRContext;
class OwningModuleRef;
@ -61,6 +62,24 @@ OwningModuleRef parseSourceFile(llvm::StringRef filename,
OwningModuleRef parseSourceString(llvm::StringRef moduleStr,
MLIRContext *context);
/// This parses a single MLIR attribute to an MLIR context if it was valid. If
/// not, an error message is emitted through a new SourceMgrDiagnosticHandler
/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
/// `attrStr`. If the passed `attrStr` has additional tokens that were not part
/// of the type, an error is emitted.
// TODO(ntv) Improve diagnostic reporting.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context);
Attribute parseAttribute(llvm::StringRef attrStr, Type type);
/// This parses a single MLIR attribute to an MLIR context if it was valid. If
/// not, an error message is emitted through a new SourceMgrDiagnosticHandler
/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
/// `attrStr`. The number of characters of `attrStr` parsed in the process is
/// returned in `numRead`.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
size_t &numRead);
Attribute parseAttribute(llvm::StringRef attrStr, Type type, size_t &numRead);
/// This parses a single MLIR type to an MLIR context if it was valid. If not,
/// an error message is emitted through a new SourceMgrDiagnosticHandler
/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping

View File

@ -4333,28 +4333,60 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr,
return parseSourceFile(sourceMgr, context);
}
Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context,
size_t &numRead) {
/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
/// parsing failed, nullptr is returned. The number of bytes read from the input
/// string is returned in 'numRead'.
template <typename T, typename ParserFn>
static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
size_t &numRead, ParserFn &&parserFn) {
SourceMgr sourceMgr;
auto memBuffer =
MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"<mlir_type_buffer>",
/*RequiresNullTerminator=*/false);
auto memBuffer = MemoryBuffer::getMemBuffer(
inputStr, /*BufferName=*/"<mlir_parser_buffer>",
/*RequiresNullTerminator=*/false);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
ParserState state(sourceMgr, context);
Parser parser(state);
auto start = parser.getToken().getLoc();
auto ty = parser.parseType();
if (!ty)
return Type();
T symbol = parserFn(parser);
if (!symbol)
return T();
auto end = parser.getToken().getLoc();
numRead = static_cast<size_t>(end.getPointer() - start.getPointer());
return ty;
return symbol;
}
Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context) {
size_t numRead = 0;
return parseAttribute(attrStr, context, numRead);
}
Attribute mlir::parseAttribute(llvm::StringRef attrStr, Type type) {
size_t numRead = 0;
return parseAttribute(attrStr, type, numRead);
}
Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
size_t &numRead) {
return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
return parser.parseAttribute();
});
}
Attribute mlir::parseAttribute(llvm::StringRef attrStr, Type type,
size_t &numRead) {
return parseSymbol<Attribute>(
attrStr, type.getContext(), numRead,
[type](Parser &parser) { return parser.parseAttribute(type); });
}
Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) {
size_t numRead = 0;
return parseType(typeStr, context, numRead);
}
Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context,
size_t &numRead) {
return parseSymbol<Type>(typeStr, context, numRead,
[](Parser &parser) { return parser.parseType(); });
}