[spirv] Add support for specialization constant

This CL extends the existing spv.constant op to also support
specialization constant by adding an extra unit attribute
on it.

PiperOrigin-RevId: 261194869
This commit is contained in:
Lei Zhang 2019-08-01 14:12:58 -07:00 committed by TensorFlower Gardener
parent 14c3c2c9fb
commit 851f0cc219
5 changed files with 215 additions and 151 deletions

View File

@ -72,58 +72,62 @@ class SPV_OpCode<string name, int val> {
// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>;
def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>;
def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>;
def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>;
def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>;
def SPV_OC_OpConstantNull : I32EnumAttrCase<"OpConstantNull", 46>;
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>;
def SPV_OC_OpFSub : I32EnumAttrCase<"OpFSub", 131>;
def SPV_OC_OpIMul : I32EnumAttrCase<"OpIMul", 132>;
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
def SPV_OC_OpUDiv : I32EnumAttrCase<"OpUDiv", 134>;
def SPV_OC_OpSDiv : I32EnumAttrCase<"OpSDiv", 135>;
def SPV_OC_OpFDiv : I32EnumAttrCase<"OpFDiv", 136>;
def SPV_OC_OpUMod : I32EnumAttrCase<"OpUMod", 137>;
def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>;
def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>;
def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>;
def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>;
def SPV_OC_OpSGreaterThan : I32EnumAttrCase<"OpSGreaterThan", 173>;
def SPV_OC_OpUGreaterThanEqual : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
def SPV_OC_OpSGreaterThanEqual : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>;
def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>;
def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>;
def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>;
def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>;
def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>;
def SPV_OC_OpConstantNull : I32EnumAttrCase<"OpConstantNull", 46>;
def SPV_OC_OpSpecConstantTrue : I32EnumAttrCase<"OpSpecConstantTrue", 48>;
def SPV_OC_OpSpecConstantFalse : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
def SPV_OC_OpSpecConstant : I32EnumAttrCase<"OpSpecConstant", 50>;
def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>;
def SPV_OC_OpFSub : I32EnumAttrCase<"OpFSub", 131>;
def SPV_OC_OpIMul : I32EnumAttrCase<"OpIMul", 132>;
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
def SPV_OC_OpUDiv : I32EnumAttrCase<"OpUDiv", 134>;
def SPV_OC_OpSDiv : I32EnumAttrCase<"OpSDiv", 135>;
def SPV_OC_OpFDiv : I32EnumAttrCase<"OpFDiv", 136>;
def SPV_OC_OpUMod : I32EnumAttrCase<"OpUMod", 137>;
def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>;
def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>;
def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>;
def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>;
def SPV_OC_OpSGreaterThan : I32EnumAttrCase<"OpSGreaterThan", 173>;
def SPV_OC_OpUGreaterThanEqual : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
def SPV_OC_OpSGreaterThanEqual : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>;
def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@ -132,14 +136,15 @@ def SPV_OpcodeAttr :
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub,
SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv,
SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem,
SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn
]> {

View File

@ -152,23 +152,24 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
### Custom assembly form
``` {.ebnf}
spv-constant-op ::= ssa-id `=` `spv.constant` attribute-value
spv-constant-op ::= ssa-id `=` `spv.constant` (`spec`)? attribute-value
(`:` spirv-type)?
```
For example:
```
%0 = spv.constant true
%1 = spv.constant dense<vector<2xf32>, [2, 3]>
%2 = spv.constant [dense<vector<2xf32>, 3.0>] : !spv.array<1xvector<2xf32>>
%0 = spv.constant spec true
%1 = spv.constant dense<[2, 3]> : vector<2xf32>
%2 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
```
TODO(antiagainst): support constant structs
}];
let arguments = (ins
AnyAttr:$value
AnyAttr:$value,
UnitAttr:$is_spec_const
);
let results = (outs

View File

@ -33,6 +33,7 @@ using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kIsSpecConstName[] = "is_spec_const";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
static constexpr const char kFnNameAttrName[] = "fn";
@ -466,6 +467,9 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
//===----------------------------------------------------------------------===//
static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
if (succeeded(parser->parseOptionalKeyword("spec")))
state->addAttribute(kIsSpecConstName, parser->getBuilder().getUnitAttr());
Attribute value;
if (parser->parseAttribute(value, kValueAttrName, state->attributes))
return failure();
@ -482,7 +486,8 @@ static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
}
static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
*printer << spirv::ConstantOp::getOperationName() << " " << constOp.value();
*printer << spirv::ConstantOp::getOperationName()
<< (constOp.is_spec_const() ? " spec " : " ") << constOp.value();
if (constOp.getType().isa<spirv::ArrayType>()) {
*printer << " : " << constOp.getType();
}

View File

@ -115,16 +115,20 @@ private:
// Constant
//===--------------------------------------------------------------------===//
/// Processes a SPIR-V OpConstant instruction with the given `operands`.
LogicalResult processConstant(ArrayRef<uint32_t> operands);
/// Processes a SPIR-V Op{|Spec}Constant instruction with the given
/// `operands`. `isSpec` indicates whether this is a specialization constant.
LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec);
/// Processes a SPIR-V OpConstant{True|False} instruction with the given
/// `operands`.
LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands);
/// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the
/// given `operands`. `isSpec` indicates whether this is a specialization
/// constant.
LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
bool isSpec);
/// Processes a SPIR-V OpConstantComposite instruction with the given
/// `operands`.
LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
/// Processes a SPIR-V Op{|Spec}ConstantComposite instruction with the given
/// `operands`. `isSpec` indicates whether this is a specialization constant.
LogicalResult processConstantComposite(ArrayRef<uint32_t> operands,
bool isSpec);
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
@ -610,14 +614,17 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
// Constant
//===----------------------------------------------------------------------===//
LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
bool isSpec) {
StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
if (operands.size() < 2) {
return emitError(unknownLoc,
"OpConstant must have type <id> and result <id>");
return emitError(unknownLoc)
<< opname << " must have type <id> and result <id>";
}
if (operands.size() < 3) {
return emitError(unknownLoc,
"OpConstant must have at least 1 more parameter");
return emitError(unknownLoc)
<< opname << " must have at least 1 more parameter";
}
Type resultType = getType(operands[0]);
@ -631,22 +638,24 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
if (operands.size() == 4) {
return success();
}
return emitError(unknownLoc,
"OpConstant should have 2 parameters for 64-bit values");
return emitError(unknownLoc)
<< opname << " should have 2 parameters for 64-bit values";
}
if (bitwidth <= 32) {
if (operands.size() == 3) {
return success();
}
return emitError(unknownLoc, "OpConstant should have 1 parameter for "
"values with no more than 32 bits");
return emitError(unknownLoc)
<< opname
<< " should have 1 parameter for values with no more than 32 bits";
}
return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
<< bitwidth;
};
spirv::ConstantOp op;
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
if (auto intType = resultType.dyn_cast<IntegerType>()) {
auto bitwidth = intType.getWidth();
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
@ -668,7 +677,8 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
}
auto attr = opBuilder.getIntegerAttr(intType, value);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr,
isSpecConst);
} else if (auto floatType = resultType.dyn_cast<FloatType>()) {
auto bitwidth = floatType.getWidth();
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
@ -693,7 +703,8 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
}
auto attr = opBuilder.getFloatAttr(floatType, value);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr,
isSpecConst);
} else {
return emitError(unknownLoc, "OpConstant can only generate values of "
"scalar integer or floating-point type");
@ -704,23 +715,27 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
}
LogicalResult Deserializer::processConstantBool(bool isTrue,
ArrayRef<uint32_t> operands) {
ArrayRef<uint32_t> operands,
bool isSpec) {
if (operands.size() != 2) {
return emitError(unknownLoc, "OpConstant")
return emitError(unknownLoc, "Op")
<< (isSpec ? "Spec" : "") << "Constant"
<< (isTrue ? "True" : "False")
<< " must have type <id> and result <id>";
}
auto attr = opBuilder.getBoolAttr(isTrue);
auto op = opBuilder.create<spirv::ConstantOp>(unknownLoc,
opBuilder.getI1Type(), attr);
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
auto op = opBuilder.create<spirv::ConstantOp>(
unknownLoc, opBuilder.getI1Type(), attr, isSpecConst);
valueMap[operands[1]] = op.getResult();
return success();
}
LogicalResult
Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
Deserializer::processConstantComposite(ArrayRef<uint32_t> operands,
bool isSpec) {
if (operands.size() < 2) {
return emitError(unknownLoc,
"OpConstantComposite must have type <id> and result <id>");
@ -757,12 +772,15 @@ Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
spirv::ConstantOp op;
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
if (auto vectorType = resultType.dyn_cast<VectorType>()) {
auto attr = opBuilder.getDenseElementsAttr(vectorType, elements);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
isSpecConst);
} else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
auto attr = opBuilder.getArrayAttr(elements);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
isSpecConst);
} else {
return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
<< resultType;
@ -788,7 +806,9 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() ||
resultType.isa<VectorType>()) {
auto attr = opBuilder.getZeroAttr(resultType);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
UnitAttr isSpecConst;
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
isSpecConst);
} else {
return emitError(unknownLoc, "unsupported OpConstantNull type: ")
<< resultType;
@ -859,13 +879,21 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
case spirv::Opcode::OpTypePointer:
return processType(opcode, operands);
case spirv::Opcode::OpConstant:
return processConstant(operands);
return processConstant(operands, /*isSpec=*/false);
case spirv::Opcode::OpSpecConstant:
return processConstant(operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantComposite:
return processConstantComposite(operands);
return processConstantComposite(operands, /*isSpec=*/false);
case spirv::Opcode::OpSpecConstantComposite:
return processConstantComposite(operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantTrue:
return processConstantBool(true, operands);
return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
case spirv::Opcode::OpSpecConstantTrue:
return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantFalse:
return processConstantBool(false, operands);
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
case spirv::Opcode::OpSpecConstantFalse:
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantNull:
return processConstantNull(operands);
case spirv::Opcode::OpDecorate:

View File

@ -168,15 +168,17 @@ private:
/// and `valueAttr`. `constType` is needed here because we can interpret the
/// `valueAttr` as a different type than the type of `valueAttr` itself; for
/// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
/// constants.
uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
/// constants. If `isSpec` is true, then the constant will be serialized as
/// a specialization constant.
uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr,
bool isSpec);
/// Prepares bool ElementsAttr serialization. This method updates `opcode`
/// with a proper OpConstant* instruction and pushes literal values for the
/// constant to `operands`.
LogicalResult prepareBoolVectorConstant(Location loc,
DenseIntElementsAttr elementsAttr,
spirv::Opcode &opcode,
bool isSpec, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares int ElementsAttr serialization. This method updates `opcode` with
@ -184,7 +186,7 @@ private:
/// constant to `operands`.
LogicalResult prepareIntVectorConstant(Location loc,
DenseIntElementsAttr elementsAttr,
spirv::Opcode &opcode,
bool isSpec, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares float ElementsAttr serialization. This method updates `opcode`
@ -192,14 +194,14 @@ private:
/// constant to `operands`.
LogicalResult prepareFloatVectorConstant(Location loc,
DenseFPElementsAttr elementsAttr,
spirv::Opcode &opcode,
bool isSpec, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr);
uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec);
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr);
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec);
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr);
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec);
//===--------------------------------------------------------------------===//
// Operations
@ -317,7 +319,8 @@ LogicalResult Serializer::processMemoryModel() {
}
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(),
op.is_spec_const())) {
valueIDMap[op.getResult()] = resultID;
return success();
}
@ -484,7 +487,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
}
operands.push_back(elementTypeID);
if (auto elementCountID = prepareConstantInt(
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
/*isSpec=*/false)) {
operands.push_back(elementCountID);
return success();
}
@ -535,15 +539,15 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
//===----------------------------------------------------------------------===//
uint32_t Serializer::prepareConstant(Location loc, Type constType,
Attribute valueAttr) {
Attribute valueAttr, bool isSpec) {
if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
return prepareConstantFp(loc, floatAttr);
return prepareConstantFp(loc, floatAttr, isSpec);
}
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
return prepareConstantInt(loc, intAttr);
return prepareConstantInt(loc, intAttr, isSpec);
}
if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
return prepareConstantBool(loc, boolAttr);
return prepareConstantBool(loc, boolAttr, isSpec);
}
// This is a composite literal. We need to handle each component separately
@ -566,21 +570,25 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
if (vectorAttr.getType().getElementType().isInteger(1)) {
if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode,
operands)))
return 0;
} else if (failed(
prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
} else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode,
operands)))
return 0;
} else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode,
operands)))
return 0;
} else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
opcode = spirv::Opcode::OpConstantComposite;
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
: spirv::Opcode::OpConstantComposite;
operands.reserve(arrayAttr.size() + 2);
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
for (Attribute elementAttr : arrayAttr)
if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
if (auto elementID =
prepareConstant(loc, elementType, elementAttr, isSpec)) {
operands.push_back(elementID);
} else {
return 0;
@ -596,8 +604,8 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
}
LogicalResult Serializer::prepareBoolVectorConstant(
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands) {
Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@ -612,13 +620,15 @@ LogicalResult Serializer::prepareBoolVectorConstant(
// the splat value is zero.
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
// We can use OpConstantNull if this bool ElementsAttr is splatting false.
if (!splatAttr.cast<BoolAttr>().getValue()) {
if (!isSpec && !splatAttr.cast<BoolAttr>().getValue()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
if (auto id = prepareConstantBool(loc, splatAttr.cast<BoolAttr>())) {
opcode = spirv::Opcode::OpConstantComposite;
if (auto id =
prepareConstantBool(loc, splatAttr.cast<BoolAttr>(), isSpec)) {
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
: spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@ -628,13 +638,14 @@ LogicalResult Serializer::prepareBoolVectorConstant(
// Otherwise, we need to process each element and compose them with
// OpConstantComposite.
opcode = spirv::Opcode::OpConstantComposite;
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
: spirv::Opcode::OpConstantComposite;
for (APInt intValue : elementsAttr) {
// We are constructing an BoolAttr for each APInt here. But given that
// we only use ElementsAttr for vectors with no more than 4 elements, it
// should be fine here.
auto boolAttr = mlirBuilder.getBoolAttr(intValue.isOneValue());
if (auto elementID = prepareConstantBool(loc, boolAttr)) {
if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) {
operands.push_back(elementID);
} else {
return failure();
@ -644,8 +655,8 @@ LogicalResult Serializer::prepareBoolVectorConstant(
}
LogicalResult Serializer::prepareIntVectorConstant(
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands) {
Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@ -661,13 +672,15 @@ LogicalResult Serializer::prepareIntVectorConstant(
// the splat value is zero.
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
// We can use OpConstantNull if this int ElementsAttr is splatting 0.
if (splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
if (!isSpec && splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
if (auto id = prepareConstantInt(loc, splatAttr.cast<IntegerAttr>())) {
opcode = spirv::Opcode::OpConstantComposite;
if (auto id =
prepareConstantInt(loc, splatAttr.cast<IntegerAttr>(), isSpec)) {
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
: spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@ -676,7 +689,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
// Otherwise, we need to process each element and compose them with
// OpConstantComposite.
opcode = spirv::Opcode::OpConstantComposite;
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
: spirv::Opcode::OpConstantComposite;
for (APInt intValue : elementsAttr) {
// We are constructing an IntegerAttr for each APInt here. But given that
// we only use ElementsAttr for vectors with no more than 4 elements, it
@ -684,7 +698,7 @@ LogicalResult Serializer::prepareIntVectorConstant(
// TODO(antiagainst): revisit this if special extensions enabling large
// vectors are supported.
auto intAttr = mlirBuilder.getIntegerAttr(elementType, intValue);
if (auto elementID = prepareConstantInt(loc, intAttr)) {
if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) {
operands.push_back(elementID);
} else {
return failure();
@ -694,8 +708,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
}
LogicalResult Serializer::prepareFloatVectorConstant(
Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands) {
Location loc, DenseFPElementsAttr elementsAttr, bool isSpec,
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@ -706,13 +720,14 @@ LogicalResult Serializer::prepareFloatVectorConstant(
operands.reserve(count + 2);
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
if (splatAttr.cast<FloatAttr>().getValue().isZero()) {
if (!isSpec && splatAttr.cast<FloatAttr>().getValue().isZero()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>())) {
opcode = spirv::Opcode::OpConstantComposite;
if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>(), isSpec)) {
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
: spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@ -720,10 +735,11 @@ LogicalResult Serializer::prepareFloatVectorConstant(
return failure();
}
opcode = spirv::Opcode::OpConstantComposite;
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
: spirv::Opcode::OpConstantComposite;
for (APFloat floatValue : elementsAttr) {
auto fpAttr = mlirBuilder.getFloatAttr(elementType, floatValue);
if (auto elementID = prepareConstantFp(loc, fpAttr)) {
if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) {
operands.push_back(elementID);
} else {
return failure();
@ -732,7 +748,8 @@ LogicalResult Serializer::prepareFloatVectorConstant(
return success();
}
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) {
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
bool isSpec) {
if (auto id = findConstantID(boolAttr)) {
return id;
}
@ -744,14 +761,18 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) {
}
auto resultID = getNextID();
auto opcode = boolAttr.getValue() ? spirv::Opcode::OpConstantTrue
: spirv::Opcode::OpConstantFalse;
auto opcode = boolAttr.getValue()
? (isSpec ? spirv::Opcode::OpSpecConstantTrue
: spirv::Opcode::OpConstantTrue)
: (isSpec ? spirv::Opcode::OpSpecConstantFalse
: spirv::Opcode::OpConstantFalse);
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
return constIDMap[boolAttr] = resultID;
}
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
bool isSpec) {
if (auto id = findConstantID(intAttr)) {
return id;
}
@ -767,6 +788,9 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
unsigned bitwidth = value.getBitWidth();
bool isSigned = value.isSignedIntN(bitwidth);
auto opcode =
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
// According to SPIR-V spec, "When the type's bit width is less than 32-bits,
// the literal's value appears in the low-order bits of the word, and the
// high-order bits must be 0 for a floating-point type, or 0 for an integer
@ -778,8 +802,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
} else {
word = static_cast<uint32_t>(value.getZExtValue());
}
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
{typeID, resultID, word});
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
}
// According to SPIR-V spec: "When the type's bit width is larger than one
// word, the literals low-order words appear first."
@ -793,7 +816,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
} else {
words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
}
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
} else {
std::string valueStr;
@ -808,7 +831,8 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
return constIDMap[intAttr] = resultID;
}
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
bool isSpec) {
if (auto id = findConstantID(floatAttr)) {
return id;
}
@ -823,22 +847,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
APFloat value = floatAttr.getValue();
APInt intValue = value.bitcastToAPInt();
auto opcode =
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
if (&value.getSemantics() == &APFloat::IEEEsingle()) {
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
{typeID, resultID, word});
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
struct DoubleWord {
uint32_t word1;
uint32_t word2;
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
uint32_t word =
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
{typeID, resultID, word});
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
} else {
std::string valueStr;
llvm::raw_string_ostream rss(valueStr);