LLVM dialect: introduce llvm.addressof to access globals

This instruction is a local counterpart of llvm.global that takes a symbol
reference to a global and produces an SSA value containing the pointer to it.
Used in combination, these two operations allow one to use globals with other
operations expecting SSA values.  At a cost of IR indirection, we make sure the
functions don't implicitly capture the surrounding SSA values and remain
suitable for parallel processing.

PiperOrigin-RevId: 262908622
This commit is contained in:
A. Unique TensorFlower 2019-08-12 06:10:29 -07:00 committed by TensorFlower Gardener
parent c4d152c612
commit c005054786
4 changed files with 96 additions and 10 deletions

View File

@ -192,7 +192,7 @@ def FCmpPredicate : I64EnumAttr<
[FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE,
FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD,
FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT,
FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
]> {
let cppNamespace = "mlir::LLVM";
@ -394,6 +394,32 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
// Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to
// work correctly).
def LLVM_AddressOfOp
: LLVM_OneResultOp<"addressof">,
Arguments<(ins SymbolRefAttr:$global_name)> {
let builders = [
OpBuilder<"Builder *builder, OperationState *result, LLVMType resType, "
"StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
result->addAttribute("global_name", builder->getSymbolRefAttr(name));
result->addAttributes(attrs);
result->addTypes(resType);}]>,
OpBuilder<"Builder *builder, OperationState *result, GlobalOp global, "
"ArrayRef<NamedAttribute> attrs = {}", [{
build(builder, result, global.getType().getPointerTo(), global.sym_name(),
attrs);}]>
];
let extraClassDeclaration = [{
/// Return the llvm.global operation that defined the value referenced here.
GlobalOp getGlobal();
}];
let printer = "printAddressOfOp(p, *this);";
let parser = "return parseAddressOfOp(parser, result);";
let verifier = "return ::verify(*this);";
}
def LLVM_GlobalOp
: LLVM_ZeroResultOp<"global">,
Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name,

View File

@ -89,6 +89,9 @@ private:
ModuleOp mlirModule;
std::unique_ptr<llvm::Module> llvmModule;
// Mappings between llvm.global definitions and corresponding globals.
llvm::DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
protected:
// Mappings between original and translated values, used for lookups.
llvm::StringMap<llvm::Function *> functionMapping;

View File

@ -788,6 +788,49 @@ static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
return success();
}
//===----------------------------------------------------------------------===//
// Printer, parser and verifier for LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//
GlobalOp AddressOfOp::getGlobal() {
auto module = getParentOfType<ModuleOp>();
assert(module && "unexpected operation outside of a module");
return module.lookupSymbol<LLVM::GlobalOp>(global_name());
}
static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) {
*p << op.getOperationName() << " @" << op.global_name();
p->printOptionalAttrDict(op.getAttrs(), {"global_name"});
*p << " : " << op.getResult()->getType();
}
static ParseResult parseAddressOfOp(OpAsmParser *parser,
OperationState *result) {
Attribute symRef;
Type type;
if (parser->parseAttribute(symRef, "global_name", result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->addTypeToList(type, result->types))
return failure();
if (!symRef.isa<SymbolRefAttr>())
return parser->emitError(parser->getNameLoc(), "expected symbol reference");
return success();
}
static LogicalResult verify(AddressOfOp op) {
auto global = op.getGlobal();
if (!global)
return op.emitOpError("must reference a global defined by 'llvm.global'");
if (global.getType().getPointerTo() != op.getResult()->getType())
return op.emitOpError(
"the type must be a pointer to the type of the referred global");
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//

View File

@ -247,6 +247,18 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
return success();
}
// Emit addressof. We need to look up the global value referenced by the
// operation and store it in the MLIR-to-LLVM value mapping. This does not
// emit any LLVM instruction.
if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
LLVM::GlobalOp global = addressOfOp.getGlobal();
// The verifier should not have allowed this.
assert(global && "referencing an undefined global");
valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global);
return success();
}
return opInst.emitError("unsupported or non-LLVM operation: ")
<< opInst.getName();
}
@ -290,21 +302,23 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
// Create named global variables that correspond to llvm.global definitions.
void ModuleTranslation::convertGlobals() {
for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
llvm::Constant *cst;
llvm::Type *type;
// String attributes are treated separately because they cannot appear as
// in-function constants and are thus not supported by getLLVMConstant.
if (auto strAttr = op.value().dyn_cast<StringAttr>()) {
llvm::Constant *cst = llvm::ConstantDataArray::getString(
cst = llvm::ConstantDataArray::getString(
llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
new llvm::GlobalVariable(*llvmModule, cst->getType(), op.constant(),
llvm::GlobalValue::InternalLinkage, cst,
op.sym_name());
return;
type = cst->getType();
} else {
type = op.getType().getUnderlyingType();
cst = getLLVMConstant(type, op.value(), op.getLoc());
}
llvm::Type *type = op.getType().getUnderlyingType();
new llvm::GlobalVariable(
*llvmModule, type, op.constant(), llvm::GlobalValue::InternalLinkage,
getLLVMConstant(type, op.value(), op.getLoc()), op.sym_name());
auto *var = new llvm::GlobalVariable(*llvmModule, type, op.constant(),
llvm::GlobalValue::InternalLinkage,
cst, op.sym_name());
globalsMapping.try_emplace(op, var);
}
}