diff --git a/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td index b626836855b..80e62847247 100644 --- a/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/third_party/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -192,7 +192,7 @@ def FCmpPredicate : I64EnumAttr< [FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE, FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD, FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT, - FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE + FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE ]> { let cppNamespace = "mlir::LLVM"; @@ -394,6 +394,32 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { // Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to // work correctly). +def LLVM_AddressOfOp + : LLVM_OneResultOp<"addressof">, + Arguments<(ins SymbolRefAttr:$global_name)> { + let builders = [ + OpBuilder<"Builder *builder, OperationState *result, LLVMType resType, " + "StringRef name, ArrayRef attrs = {}", [{ + result->addAttribute("global_name", builder->getSymbolRefAttr(name)); + result->addAttributes(attrs); + result->addTypes(resType);}]>, + + OpBuilder<"Builder *builder, OperationState *result, GlobalOp global, " + "ArrayRef attrs = {}", [{ + build(builder, result, global.getType().getPointerTo(), global.sym_name(), + attrs);}]> + ]; + + let extraClassDeclaration = [{ + /// Return the llvm.global operation that defined the value referenced here. + GlobalOp getGlobal(); + }]; + + let printer = "printAddressOfOp(p, *this);"; + let parser = "return parseAddressOfOp(parser, result);"; + let verifier = "return ::verify(*this);"; +} + def LLVM_GlobalOp : LLVM_ZeroResultOp<"global">, Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name, diff --git a/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 04651b8cf6e..584d2a84fe9 100644 --- a/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -89,6 +89,9 @@ private: ModuleOp mlirModule; std::unique_ptr llvmModule; + // Mappings between llvm.global definitions and corresponding globals. + llvm::DenseMap globalsMapping; + protected: // Mappings between original and translated values, used for lookups. llvm::StringMap functionMapping; diff --git a/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 378907eff0b..199d40150dc 100644 --- a/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/third_party/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -788,6 +788,49 @@ static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) { return success(); } +//===----------------------------------------------------------------------===// +// Printer, parser and verifier for LLVM::AddressOfOp. +//===----------------------------------------------------------------------===// + +GlobalOp AddressOfOp::getGlobal() { + auto module = getParentOfType(); + assert(module && "unexpected operation outside of a module"); + return module.lookupSymbol(global_name()); +} + +static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) { + *p << op.getOperationName() << " @" << op.global_name(); + p->printOptionalAttrDict(op.getAttrs(), {"global_name"}); + *p << " : " << op.getResult()->getType(); +} + +static ParseResult parseAddressOfOp(OpAsmParser *parser, + OperationState *result) { + Attribute symRef; + Type type; + if (parser->parseAttribute(symRef, "global_name", result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->addTypeToList(type, result->types)) + return failure(); + + if (!symRef.isa()) + return parser->emitError(parser->getNameLoc(), "expected symbol reference"); + return success(); +} + +static LogicalResult verify(AddressOfOp op) { + auto global = op.getGlobal(); + if (!global) + return op.emitOpError("must reference a global defined by 'llvm.global'"); + + if (global.getType().getPointerTo() != op.getResult()->getType()) + return op.emitOpError( + "the type must be a pointer to the type of the referred global"); + + return success(); +} + //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::ConstantOp. //===----------------------------------------------------------------------===// diff --git a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 5e1109bbdd0..7a84eaea0a1 100644 --- a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -247,6 +247,18 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, return success(); } + // Emit addressof. We need to look up the global value referenced by the + // operation and store it in the MLIR-to-LLVM value mapping. This does not + // emit any LLVM instruction. + if (auto addressOfOp = dyn_cast(opInst)) { + LLVM::GlobalOp global = addressOfOp.getGlobal(); + // The verifier should not have allowed this. + assert(global && "referencing an undefined global"); + + valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global); + return success(); + } + return opInst.emitError("unsupported or non-LLVM operation: ") << opInst.getName(); } @@ -290,21 +302,23 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { // Create named global variables that correspond to llvm.global definitions. void ModuleTranslation::convertGlobals() { for (auto op : mlirModule.getOps()) { + llvm::Constant *cst; + llvm::Type *type; // String attributes are treated separately because they cannot appear as // in-function constants and are thus not supported by getLLVMConstant. if (auto strAttr = op.value().dyn_cast()) { - llvm::Constant *cst = llvm::ConstantDataArray::getString( + cst = llvm::ConstantDataArray::getString( llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); - new llvm::GlobalVariable(*llvmModule, cst->getType(), op.constant(), - llvm::GlobalValue::InternalLinkage, cst, - op.sym_name()); - return; + type = cst->getType(); + } else { + type = op.getType().getUnderlyingType(); + cst = getLLVMConstant(type, op.value(), op.getLoc()); } - llvm::Type *type = op.getType().getUnderlyingType(); - new llvm::GlobalVariable( - *llvmModule, type, op.constant(), llvm::GlobalValue::InternalLinkage, - getLLVMConstant(type, op.value(), op.getLoc()), op.sym_name()); + auto *var = new llvm::GlobalVariable(*llvmModule, type, op.constant(), + llvm::GlobalValue::InternalLinkage, + cst, op.sym_name()); + globalsMapping.try_emplace(op, var); } }