diff --git a/third_party/mlir/include/mlir/IR/FunctionSupport.h b/third_party/mlir/include/mlir/IR/FunctionSupport.h index a70013a1caf..192f5dd3342 100644 --- a/third_party/mlir/include/mlir/IR/FunctionSupport.h +++ b/third_party/mlir/include/mlir/IR/FunctionSupport.h @@ -55,15 +55,16 @@ inline ArrayRef getArgAttrs(Operation *op, unsigned index) { /// Callback type for `parseFunctionLikeOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of -/// function arguments and results. -using FuncTypeBuilder = - llvm::function_ref, ArrayRef)>; +/// function arguments and results; in case of error, it may populate the last +/// argument with a message. +using FuncTypeBuilder = llvm::function_ref, + ArrayRef, std::string &)>; /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of /// input and output types. If the builder returns a null type, `result` will -/// not contain the `type` attribute. The caller can then either add the type -/// or use op's verifier to report errors. +/// not contain the `type` attribute. The caller can then add a type, report +/// the error or delegate the reporting to the op's verifier. ParseResult parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, FuncTypeBuilder funcTypeBuilder); diff --git a/third_party/mlir/include/mlir/LLVMIR/LLVMDialect.h b/third_party/mlir/include/mlir/LLVMIR/LLVMDialect.h index 2f98828b102..55479f22c63 100644 --- a/third_party/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/third_party/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -67,6 +67,11 @@ public: /// Array type utilities. LLVMType getArrayElementType(); + /// Function type utilities. + LLVMType getFunctionParamType(unsigned argIdx); + unsigned getFunctionNumParams(); + LLVMType getFunctionResultType(); + /// Pointer type utilities. LLVMType getPointerTo(unsigned addrSpace = 0); LLVMType getPointerElementTy(); diff --git a/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td index 5c013916fd2..9031242dc22 100644 --- a/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -337,6 +337,10 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", }]; let verifier = [{ return ::verify(*this); }]; + let printer = [{ printLLVMFuncOp(p, *this); }]; + let parser = [{ + return impl::parseFunctionLikeOp(parser, result, buildLLVMFunctionType); + }]; } def LLVM_UndefOp : LLVM_OneResultOp<"undef", [NoSideEffect]>, diff --git a/third_party/mlir/lib/IR/Function.cpp b/third_party/mlir/lib/IR/Function.cpp index 106b670cac4..af0edf970ad 100644 --- a/third_party/mlir/lib/IR/Function.cpp +++ b/third_party/mlir/lib/IR/Function.cpp @@ -78,9 +78,8 @@ void FuncOp::build(Builder *builder, OperationState *result, StringRef name, ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) { return impl::parseFunctionLikeOp( parser, result, - [](Builder &builder, ArrayRef argTypes, ArrayRef results) { - return builder.getFunctionType(argTypes, results); - }); + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + std::string &) { return builder.getFunctionType(argTypes, results); }); } void FuncOp::print(OpAsmPrinter *p) { diff --git a/third_party/mlir/lib/IR/FunctionSupport.cpp b/third_party/mlir/lib/IR/FunctionSupport.cpp index 081da758be5..92285e4ba21 100644 --- a/third_party/mlir/lib/IR/FunctionSupport.cpp +++ b/third_party/mlir/lib/IR/FunctionSupport.cpp @@ -110,11 +110,17 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, result->attributes.back().second = builder.getStringAttr(nameAttr.getValue()); // Parse the function signature. + auto signatureLocation = parser->getCurrentLocation(); if (parseFunctionSignature(parser, entryArgs, argTypes, argAttrs, results)) return failure(); - if (auto type = funcTypeBuilder(builder, argTypes, results)) + std::string errorMessage; + if (auto type = funcTypeBuilder(builder, argTypes, results, errorMessage)) result->addAttribute(getTypeAttrName(), builder.getTypeAttr(type)); + else + return parser->emitError(signatureLocation) + << "failed to construct function type" + << (errorMessage.empty() ? "" : ": ") << errorMessage; // If function attributes are present, parse them. if (succeeded(parser->parseOptionalKeyword("attributes"))) diff --git a/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index da46e8d5f0c..1315fdd6bd2 100644 --- a/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -703,7 +703,7 @@ static ParseResult parseConstantOp(OpAsmParser *parser, } //===----------------------------------------------------------------------===// -// Builder and verifier for LLVM::LLVMFuncOp. +// Builder, printer and verifier for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name, @@ -726,6 +726,62 @@ void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name, result->addAttribute(getArgAttrName(i, argAttrName), argDict); } +// Build an LLVM function type from the given lists of input and output types. +// Returns a null type if any of the types provided are non-LLVM types, or if +// there is more than one output type. +static Type buildLLVMFunctionType(Builder &b, ArrayRef inputs, + ArrayRef outputs, + std::string &errorMessage) { + if (outputs.size() > 1) { + errorMessage = "expected zero or one function result"; + return {}; + } + + // Convert inputs to LLVM types, exit early on error. + SmallVector llvmInputs; + for (auto t : inputs) { + auto llvmTy = t.dyn_cast(); + if (!llvmTy) { + errorMessage = "expected LLVM type for function arguments"; + return {}; + } + llvmInputs.push_back(llvmTy); + } + + // Get the dialect from the input type, if any exist. Look it up in the + // context otherwise. + LLVMDialect *dialect = + llvmInputs.empty() ? b.getContext()->getRegisteredDialect() + : &llvmInputs.front().getDialect(); + + // No output is denoted as "void" in LLVM type system. + LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) + : outputs.front().dyn_cast(); + if (!llvmOutput) { + errorMessage = "expected LLVM type for function results"; + return {}; + } + return LLVMType::getFunctionTy(llvmOutput, llvmInputs, + /*isVarArg=*/false); +} + +// Print the LLVMFuncOp. Collects argument and result types and passes them +// to the trait printer. Drops "void" result since it cannot be parsed back. +static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) { + LLVMType fnType = op.getType(); + SmallVector argTypes; + SmallVector resTypes; + argTypes.reserve(fnType.getFunctionNumParams()); + for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i) + argTypes.push_back(fnType.getFunctionParamType(i)); + + LLVMType returnType = fnType.getFunctionResultType(); + if (!returnType.getUnderlyingType()->isVoidTy()) + resTypes.push_back(returnType); + + impl::printFunctionLikeOp(p, op, argTypes, resTypes); +} + // Hook for OpTrait::FunctionLike, called after verifying that the 'type' // attribute is present. This can check for preconditions of the // getNumArguments hook not failing. @@ -914,6 +970,19 @@ LLVMType LLVMType::getArrayElementType() { return get(getContext(), getUnderlyingType()->getArrayElementType()); } +/// Function type utilities. +LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { + return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx)); +} +unsigned LLVMType::getFunctionNumParams() { + return getUnderlyingType()->getFunctionNumParams(); +} +LLVMType LLVMType::getFunctionResultType() { + return get( + getContext(), + llvm::cast(getUnderlyingType())->getReturnType()); +} + /// Pointer type utilities. LLVMType LLVMType::getPointerTo(unsigned addrSpace) { // Lock access to the dialect as this may modify the LLVM context.