[llvm] Allow GlobalOp to take a region for complex initializers

This allows GlobalOp to either take a value attribute (for simple constants) or a region that can
contain IR instructions (that must be constant-foldable) to create a ConstantExpr initializer.

Example:
  // A complex initializer is constructed with an initializer region.
  llvm.mlir.global constant @int_gep() : !llvm<"i32*"> {
    %0 = llvm.mlir.addressof @g2 : !llvm<"i32*">
    %1 = llvm.mlir.constant(2 : i32) : !llvm.i32
    %2 = llvm.getelementptr %0[%1] : (!llvm<"i32*">, !llvm.i32) -> !llvm<"i32*">
    llvm.return %2 : !llvm<"i32*">
  }
PiperOrigin-RevId: 278717836
Change-Id: I54ed196a361dd2ca4c564570d0ca9ed12e2f1f95
This commit is contained in:
A. Unique TensorFlower 2019-11-05 15:10:28 -08:00 committed by TensorFlower Gardener
parent 1195278c1b
commit b422cb120c
4 changed files with 92 additions and 22 deletions

View File

@ -495,10 +495,28 @@ def LLVM_AddressOfOp
} }
def LLVM_GlobalOp def LLVM_GlobalOp
: LLVM_ZeroResultOp<"mlir.global", [Symbol]>, : LLVM_ZeroResultOp<"mlir.global",
[IsolatedFromAbove,
SingleBlockImplicitTerminator<"ReturnOp">, Symbol]>,
Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name, Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name,
OptionalAttr<AnyAttr>:$value, OptionalAttr<AnyAttr>:$value,
DefaultValuedAttr<NonNegativeI32Attr, "0">:$addr_space)> { DefaultValuedAttr<NonNegativeI32Attr, "0">:$addr_space)> {
let summary = "LLVM dialect global.";
let description = [{
Can contain an optional initializer region or attribute for simple
initializers.
Examples:
// Initialized using an attribute.
llvm.mlir.global @a("abc") : !llvm<"[3 x i8]">
// Initialized using a region.
llvm.mlir.global constant @b() : !llvm<"i32*"> {
%0 = llvm.constant(0 : i32) : !llvm.i32
%1 = llvm.inttoptr %0 : !llvm.i32 to !llvm<"i32*">
llvm.return %1 : !llvm<"i32*">
}
}];
let regions = (region AnyRegion:$initializer);
let builders = [ let builders = [
OpBuilder<"Builder *builder, OperationState &result, LLVMType type, " OpBuilder<"Builder *builder, OperationState &result, LLVMType type, "
@ -511,10 +529,22 @@ def LLVM_GlobalOp
LLVMType getType() { LLVMType getType() {
return type().cast<LLVMType>(); return type().cast<LLVMType>();
} }
/// Return the value attribute if it exists, or a null attribute. /// Return the initializer attribute if it exists, or a null attribute.
Attribute getValueOrNull() { Attribute getValueOrNull() {
return value().getValueOr(Attribute()); return value().getValueOr(Attribute());
} }
/// Return the initializer region. This may be empty, but if it is not it
/// terminates in an `llvm.return` op with the initializer value.
Region &getInitializerRegion() {
return getOperation()->getRegion(0);
}
/// Return the initializer block. If the initializer region is empty this
/// is nullptr. If it is not nullptr, it terminates with an `llvm.return`
/// op with the initializer value.
Block *getInitializerBlock() {
return getInitializerRegion().empty() ?
nullptr : &getInitializerRegion().front();
}
}]; }];
let printer = "printGlobalOp(p, *this);"; let printer = "printGlobalOp(p, *this);";

View File

@ -875,6 +875,7 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
if (value) if (value)
result.addAttribute("value", value); result.addAttribute("value", value);
result.attributes.append(attrs.begin(), attrs.end()); result.attributes.append(attrs.begin(), attrs.end());
result.addRegion();
} }
static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
@ -894,10 +895,14 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
return; return;
p << " : "; p << " : ";
p.printType(op.type()); p.printType(op.type());
Region &initializer = op.getInitializerRegion();
if (!initializer.empty())
p.printRegion(initializer, /*printEntryBlockArgs=*/false);
} }
// <operation> ::= `llvm.mlir.global` `constant`? `@` identifier // <operation> ::= `llvm.mlir.global` `constant`? `@` identifier
// `(` attribute? `)` attribute-list? (`:` type)? // `(` attribute? `)` attribute-list? (`:` type)? region?
// //
// The type can be omitted for string attributes, in which case it will be // The type can be omitted for string attributes, in which case it will be
// inferred from the value of the string as [strlen(value) x i8]. // inferred from the value of the string as [strlen(value) x i8].
@ -926,6 +931,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
if (types.size() > 1) if (types.size() > 1)
return parser.emitError(parser.getNameLoc(), "expected zero or one type"); return parser.emitError(parser.getNameLoc(), "expected zero or one type");
Region &initRegion = *result.addRegion();
if (types.empty()) { if (types.empty()) {
if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
MLIRContext *context = parser.getBuilder().getContext(); MLIRContext *context = parser.getBuilder().getContext();
@ -937,6 +943,9 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
return parser.emitError(parser.getNameLoc(), return parser.emitError(parser.getNameLoc(),
"type can only be omitted for string globals"); "type can only be omitted for string globals");
} }
} else if (parser.parseOptionalRegion(initRegion, /*arguments=*/{},
/*argTypes=*/{})) {
return failure();
} }
result.addAttribute("type", TypeAttr::get(types[0])); result.addAttribute("type", TypeAttr::get(types[0]));
@ -959,6 +968,19 @@ static LogicalResult verify(GlobalOp op) {
"requires an i8 array type of the length equal to that of the string " "requires an i8 array type of the length equal to that of the string "
"attribute"); "attribute");
} }
if (Block *b = op.getInitializerBlock()) {
ReturnOp ret = cast<ReturnOp>(b->getTerminator());
if (ret.operand_type_begin() == ret.operand_type_end())
return op.emitOpError("initializer region cannot return void");
if (*ret.operand_type_begin() != op.getType())
return op.emitOpError("initializer region type ")
<< *ret.operand_type_begin() << " does not match global type "
<< op.getType();
if (op.getValueOrNull())
return op.emitOpError("cannot have both initializer value and region");
}
return success(); return success();
} }

View File

@ -212,8 +212,6 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
if (auto *c = dyn_cast<llvm::ConstantDataArray>(value)) if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
if (c->isString()) if (c->isString())
return b.getStringAttr(c->getAsString()); return b.getStringAttr(c->getAsString());
emitError(unknownLoc) << "unhandled constant type for attribute: "
<< diag(*value);
return Attribute(); return Attribute();
} }
@ -226,17 +224,25 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
Attribute valueAttr; Attribute valueAttr;
if (GV->hasInitializer()) if (GV->hasInitializer())
valueAttr = getConstantAsAttr(GV->getInitializer()); valueAttr = getConstantAsAttr(GV->getInitializer());
return globals[GV] = b.create<GlobalOp>( GlobalOp op = b.create<GlobalOp>(UnknownLoc::get(context),
UnknownLoc::get(context), processType(GV->getValueType()), processType(GV->getValueType()),
GV->isConstant(), GV->getName(), valueAttr); GV->isConstant(), GV->getName(), valueAttr);
if (GV->hasInitializer() && !valueAttr) {
Region &r = op.getInitializerRegion();
currentEntryBlock = b.createBlock(&r);
b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
Value *v = processConstant(GV->getInitializer());
b.create<ReturnOp>(op.getLoc(), ArrayRef<Value *>({v}));
}
return globals[GV] = op;
} }
Value *Importer::processConstant(llvm::Constant *c) { Value *Importer::processConstant(llvm::Constant *c) {
if (isa<llvm::ConstantInt>(c) || isa<llvm::ConstantDataArray>(c)) { if (Attribute attr = getConstantAsAttr(c)) {
// These constants can be represented as attributes. // These constants can be represented as attributes.
OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
return instMap[c] = b.create<ConstantOp>( return instMap[c] = b.create<ConstantOp>(unknownLoc,
unknownLoc, processType(c->getType()), getConstantAsAttr(c)); processType(c->getType()), attr);
} }
if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) { if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); OpBuilder b(currentEntryBlock, currentEntryBlock->begin());

View File

@ -283,17 +283,29 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
// definitions. // definitions.
void ModuleTranslation::convertGlobals() { void ModuleTranslation::convertGlobals() {
for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) { for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
llvm::Constant *cst; llvm::Type *type = op.getType().getUnderlyingType();
llvm::Type *type; llvm::Constant *cst = llvm::UndefValue::get(type);
// String attributes are treated separately because they cannot appear as if (op.getValueOrNull()) {
// in-function constants and are thus not supported by getLLVMConstant. // String attributes are treated separately because they cannot appear as
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { // in-function constants and are thus not supported by getLLVMConstant.
cst = llvm::ConstantDataArray::getString( if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); cst = llvm::ConstantDataArray::getString(
type = cst->getType(); llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
} else { type = cst->getType();
type = op.getType().getUnderlyingType(); } else {
cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc()); cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc());
}
} else if (Block *initializer = op.getInitializerBlock()) {
llvm::IRBuilder<> builder(llvmModule->getContext());
for (auto &op : initializer->without_terminator()) {
if (failed(convertOperation(op, builder)) ||
!isa<llvm::Constant>(valueMapping.lookup(op.getResult(0)))) {
emitError(op.getLoc(), "unemittable constant value");
return;
}
}
ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
cst = cast<llvm::Constant>(valueMapping.lookup(ret.getOperand(0)));
} }
auto addrSpace = op.addr_space().getLimitedValue(); auto addrSpace = op.addr_space().getLimitedValue();