[spirv] Add support for specialization constant
This CL extends the existing spv.constant op to also support specialization constant by adding an extra unit attribute on it. PiperOrigin-RevId: 261194869
This commit is contained in:
parent
14c3c2c9fb
commit
851f0cc219
@ -72,58 +72,62 @@ class SPV_OpCode<string name, int val> {
|
||||
|
||||
// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
|
||||
|
||||
def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
|
||||
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
|
||||
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
|
||||
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
|
||||
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
|
||||
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
|
||||
def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>;
|
||||
def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
|
||||
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
|
||||
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
|
||||
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
|
||||
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
|
||||
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
|
||||
def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>;
|
||||
def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>;
|
||||
def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>;
|
||||
def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>;
|
||||
def SPV_OC_OpConstantNull : I32EnumAttrCase<"OpConstantNull", 46>;
|
||||
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
|
||||
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
|
||||
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
|
||||
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
|
||||
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
|
||||
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
|
||||
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
|
||||
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
|
||||
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
|
||||
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
|
||||
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
|
||||
def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>;
|
||||
def SPV_OC_OpFSub : I32EnumAttrCase<"OpFSub", 131>;
|
||||
def SPV_OC_OpIMul : I32EnumAttrCase<"OpIMul", 132>;
|
||||
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
|
||||
def SPV_OC_OpUDiv : I32EnumAttrCase<"OpUDiv", 134>;
|
||||
def SPV_OC_OpSDiv : I32EnumAttrCase<"OpSDiv", 135>;
|
||||
def SPV_OC_OpFDiv : I32EnumAttrCase<"OpFDiv", 136>;
|
||||
def SPV_OC_OpUMod : I32EnumAttrCase<"OpUMod", 137>;
|
||||
def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>;
|
||||
def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
|
||||
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
|
||||
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
|
||||
def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>;
|
||||
def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>;
|
||||
def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>;
|
||||
def SPV_OC_OpSGreaterThan : I32EnumAttrCase<"OpSGreaterThan", 173>;
|
||||
def SPV_OC_OpUGreaterThanEqual : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
|
||||
def SPV_OC_OpSGreaterThanEqual : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
|
||||
def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>;
|
||||
def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
|
||||
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
|
||||
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
|
||||
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
|
||||
def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
|
||||
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
|
||||
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
|
||||
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
|
||||
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
|
||||
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
|
||||
def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>;
|
||||
def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
|
||||
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
|
||||
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
|
||||
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
|
||||
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
|
||||
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
|
||||
def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>;
|
||||
def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>;
|
||||
def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>;
|
||||
def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>;
|
||||
def SPV_OC_OpConstantNull : I32EnumAttrCase<"OpConstantNull", 46>;
|
||||
def SPV_OC_OpSpecConstantTrue : I32EnumAttrCase<"OpSpecConstantTrue", 48>;
|
||||
def SPV_OC_OpSpecConstantFalse : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
|
||||
def SPV_OC_OpSpecConstant : I32EnumAttrCase<"OpSpecConstant", 50>;
|
||||
def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
|
||||
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
|
||||
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
|
||||
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
|
||||
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
|
||||
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
|
||||
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
|
||||
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
|
||||
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
|
||||
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
|
||||
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
|
||||
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
|
||||
def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>;
|
||||
def SPV_OC_OpFSub : I32EnumAttrCase<"OpFSub", 131>;
|
||||
def SPV_OC_OpIMul : I32EnumAttrCase<"OpIMul", 132>;
|
||||
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
|
||||
def SPV_OC_OpUDiv : I32EnumAttrCase<"OpUDiv", 134>;
|
||||
def SPV_OC_OpSDiv : I32EnumAttrCase<"OpSDiv", 135>;
|
||||
def SPV_OC_OpFDiv : I32EnumAttrCase<"OpFDiv", 136>;
|
||||
def SPV_OC_OpUMod : I32EnumAttrCase<"OpUMod", 137>;
|
||||
def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>;
|
||||
def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
|
||||
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
|
||||
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
|
||||
def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>;
|
||||
def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>;
|
||||
def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>;
|
||||
def SPV_OC_OpSGreaterThan : I32EnumAttrCase<"OpSGreaterThan", 173>;
|
||||
def SPV_OC_OpUGreaterThanEqual : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
|
||||
def SPV_OC_OpSGreaterThanEqual : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
|
||||
def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>;
|
||||
def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
|
||||
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
|
||||
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
|
||||
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
|
||||
|
||||
def SPV_OpcodeAttr :
|
||||
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
|
||||
@ -132,14 +136,15 @@ def SPV_OpcodeAttr :
|
||||
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
|
||||
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
|
||||
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
|
||||
SPV_OC_OpConstantNull, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
|
||||
SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
|
||||
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
|
||||
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
|
||||
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
|
||||
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual,
|
||||
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
|
||||
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
|
||||
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
|
||||
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
|
||||
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
|
||||
SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub,
|
||||
SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv,
|
||||
SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem,
|
||||
SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
|
||||
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
|
||||
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
|
||||
SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn
|
||||
]> {
|
||||
|
||||
@ -152,23 +152,24 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
spv-constant-op ::= ssa-id `=` `spv.constant` attribute-value
|
||||
spv-constant-op ::= ssa-id `=` `spv.constant` (`spec`)? attribute-value
|
||||
(`:` spirv-type)?
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%0 = spv.constant true
|
||||
%1 = spv.constant dense<vector<2xf32>, [2, 3]>
|
||||
%2 = spv.constant [dense<vector<2xf32>, 3.0>] : !spv.array<1xvector<2xf32>>
|
||||
%0 = spv.constant spec true
|
||||
%1 = spv.constant dense<[2, 3]> : vector<2xf32>
|
||||
%2 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
|
||||
```
|
||||
|
||||
TODO(antiagainst): support constant structs
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyAttr:$value
|
||||
AnyAttr:$value,
|
||||
UnitAttr:$is_spec_const
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
||||
@ -33,6 +33,7 @@ using namespace mlir;
|
||||
// TODO(antiagainst): generate these strings using ODS.
|
||||
static constexpr const char kAlignmentAttrName[] = "alignment";
|
||||
static constexpr const char kIndicesAttrName[] = "indices";
|
||||
static constexpr const char kIsSpecConstName[] = "is_spec_const";
|
||||
static constexpr const char kValueAttrName[] = "value";
|
||||
static constexpr const char kValuesAttrName[] = "values";
|
||||
static constexpr const char kFnNameAttrName[] = "fn";
|
||||
@ -466,6 +467,9 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
|
||||
if (succeeded(parser->parseOptionalKeyword("spec")))
|
||||
state->addAttribute(kIsSpecConstName, parser->getBuilder().getUnitAttr());
|
||||
|
||||
Attribute value;
|
||||
if (parser->parseAttribute(value, kValueAttrName, state->attributes))
|
||||
return failure();
|
||||
@ -482,7 +486,8 @@ static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
|
||||
}
|
||||
|
||||
static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::ConstantOp::getOperationName() << " " << constOp.value();
|
||||
*printer << spirv::ConstantOp::getOperationName()
|
||||
<< (constOp.is_spec_const() ? " spec " : " ") << constOp.value();
|
||||
if (constOp.getType().isa<spirv::ArrayType>()) {
|
||||
*printer << " : " << constOp.getType();
|
||||
}
|
||||
|
||||
@ -115,16 +115,20 @@ private:
|
||||
// Constant
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Processes a SPIR-V OpConstant instruction with the given `operands`.
|
||||
LogicalResult processConstant(ArrayRef<uint32_t> operands);
|
||||
/// Processes a SPIR-V Op{|Spec}Constant instruction with the given
|
||||
/// `operands`. `isSpec` indicates whether this is a specialization constant.
|
||||
LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec);
|
||||
|
||||
/// Processes a SPIR-V OpConstant{True|False} instruction with the given
|
||||
/// `operands`.
|
||||
LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands);
|
||||
/// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the
|
||||
/// given `operands`. `isSpec` indicates whether this is a specialization
|
||||
/// constant.
|
||||
LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
|
||||
bool isSpec);
|
||||
|
||||
/// Processes a SPIR-V OpConstantComposite instruction with the given
|
||||
/// `operands`.
|
||||
LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
|
||||
/// Processes a SPIR-V Op{|Spec}ConstantComposite instruction with the given
|
||||
/// `operands`. `isSpec` indicates whether this is a specialization constant.
|
||||
LogicalResult processConstantComposite(ArrayRef<uint32_t> operands,
|
||||
bool isSpec);
|
||||
|
||||
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
|
||||
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
|
||||
@ -610,14 +614,17 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
|
||||
// Constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
|
||||
LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
|
||||
bool isSpec) {
|
||||
StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
|
||||
|
||||
if (operands.size() < 2) {
|
||||
return emitError(unknownLoc,
|
||||
"OpConstant must have type <id> and result <id>");
|
||||
return emitError(unknownLoc)
|
||||
<< opname << " must have type <id> and result <id>";
|
||||
}
|
||||
if (operands.size() < 3) {
|
||||
return emitError(unknownLoc,
|
||||
"OpConstant must have at least 1 more parameter");
|
||||
return emitError(unknownLoc)
|
||||
<< opname << " must have at least 1 more parameter";
|
||||
}
|
||||
|
||||
Type resultType = getType(operands[0]);
|
||||
@ -631,22 +638,24 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() == 4) {
|
||||
return success();
|
||||
}
|
||||
return emitError(unknownLoc,
|
||||
"OpConstant should have 2 parameters for 64-bit values");
|
||||
return emitError(unknownLoc)
|
||||
<< opname << " should have 2 parameters for 64-bit values";
|
||||
}
|
||||
if (bitwidth <= 32) {
|
||||
if (operands.size() == 3) {
|
||||
return success();
|
||||
}
|
||||
|
||||
return emitError(unknownLoc, "OpConstant should have 1 parameter for "
|
||||
"values with no more than 32 bits");
|
||||
return emitError(unknownLoc)
|
||||
<< opname
|
||||
<< " should have 1 parameter for values with no more than 32 bits";
|
||||
}
|
||||
return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
|
||||
<< bitwidth;
|
||||
};
|
||||
|
||||
spirv::ConstantOp op;
|
||||
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
|
||||
if (auto intType = resultType.dyn_cast<IntegerType>()) {
|
||||
auto bitwidth = intType.getWidth();
|
||||
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
|
||||
@ -668,7 +677,8 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
|
||||
}
|
||||
|
||||
auto attr = opBuilder.getIntegerAttr(intType, value);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr,
|
||||
isSpecConst);
|
||||
} else if (auto floatType = resultType.dyn_cast<FloatType>()) {
|
||||
auto bitwidth = floatType.getWidth();
|
||||
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
|
||||
@ -693,7 +703,8 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
|
||||
}
|
||||
|
||||
auto attr = opBuilder.getFloatAttr(floatType, value);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr,
|
||||
isSpecConst);
|
||||
} else {
|
||||
return emitError(unknownLoc, "OpConstant can only generate values of "
|
||||
"scalar integer or floating-point type");
|
||||
@ -704,23 +715,27 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processConstantBool(bool isTrue,
|
||||
ArrayRef<uint32_t> operands) {
|
||||
ArrayRef<uint32_t> operands,
|
||||
bool isSpec) {
|
||||
if (operands.size() != 2) {
|
||||
return emitError(unknownLoc, "OpConstant")
|
||||
return emitError(unknownLoc, "Op")
|
||||
<< (isSpec ? "Spec" : "") << "Constant"
|
||||
<< (isTrue ? "True" : "False")
|
||||
<< " must have type <id> and result <id>";
|
||||
}
|
||||
|
||||
auto attr = opBuilder.getBoolAttr(isTrue);
|
||||
auto op = opBuilder.create<spirv::ConstantOp>(unknownLoc,
|
||||
opBuilder.getI1Type(), attr);
|
||||
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
|
||||
auto op = opBuilder.create<spirv::ConstantOp>(
|
||||
unknownLoc, opBuilder.getI1Type(), attr, isSpecConst);
|
||||
|
||||
valueMap[operands[1]] = op.getResult();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
|
||||
Deserializer::processConstantComposite(ArrayRef<uint32_t> operands,
|
||||
bool isSpec) {
|
||||
if (operands.size() < 2) {
|
||||
return emitError(unknownLoc,
|
||||
"OpConstantComposite must have type <id> and result <id>");
|
||||
@ -757,12 +772,15 @@ Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
|
||||
}
|
||||
|
||||
spirv::ConstantOp op;
|
||||
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
|
||||
if (auto vectorType = resultType.dyn_cast<VectorType>()) {
|
||||
auto attr = opBuilder.getDenseElementsAttr(vectorType, elements);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
|
||||
isSpecConst);
|
||||
} else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
|
||||
auto attr = opBuilder.getArrayAttr(elements);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
|
||||
isSpecConst);
|
||||
} else {
|
||||
return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
|
||||
<< resultType;
|
||||
@ -788,7 +806,9 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
|
||||
if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() ||
|
||||
resultType.isa<VectorType>()) {
|
||||
auto attr = opBuilder.getZeroAttr(resultType);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
|
||||
UnitAttr isSpecConst;
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
|
||||
isSpecConst);
|
||||
} else {
|
||||
return emitError(unknownLoc, "unsupported OpConstantNull type: ")
|
||||
<< resultType;
|
||||
@ -859,13 +879,21 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
||||
case spirv::Opcode::OpTypePointer:
|
||||
return processType(opcode, operands);
|
||||
case spirv::Opcode::OpConstant:
|
||||
return processConstant(operands);
|
||||
return processConstant(operands, /*isSpec=*/false);
|
||||
case spirv::Opcode::OpSpecConstant:
|
||||
return processConstant(operands, /*isSpec=*/true);
|
||||
case spirv::Opcode::OpConstantComposite:
|
||||
return processConstantComposite(operands);
|
||||
return processConstantComposite(operands, /*isSpec=*/false);
|
||||
case spirv::Opcode::OpSpecConstantComposite:
|
||||
return processConstantComposite(operands, /*isSpec=*/true);
|
||||
case spirv::Opcode::OpConstantTrue:
|
||||
return processConstantBool(true, operands);
|
||||
return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
|
||||
case spirv::Opcode::OpSpecConstantTrue:
|
||||
return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
|
||||
case spirv::Opcode::OpConstantFalse:
|
||||
return processConstantBool(false, operands);
|
||||
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
|
||||
case spirv::Opcode::OpSpecConstantFalse:
|
||||
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
|
||||
case spirv::Opcode::OpConstantNull:
|
||||
return processConstantNull(operands);
|
||||
case spirv::Opcode::OpDecorate:
|
||||
|
||||
@ -168,15 +168,17 @@ private:
|
||||
/// and `valueAttr`. `constType` is needed here because we can interpret the
|
||||
/// `valueAttr` as a different type than the type of `valueAttr` itself; for
|
||||
/// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
|
||||
/// constants.
|
||||
uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
|
||||
/// constants. If `isSpec` is true, then the constant will be serialized as
|
||||
/// a specialization constant.
|
||||
uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr,
|
||||
bool isSpec);
|
||||
|
||||
/// Prepares bool ElementsAttr serialization. This method updates `opcode`
|
||||
/// with a proper OpConstant* instruction and pushes literal values for the
|
||||
/// constant to `operands`.
|
||||
LogicalResult prepareBoolVectorConstant(Location loc,
|
||||
DenseIntElementsAttr elementsAttr,
|
||||
spirv::Opcode &opcode,
|
||||
bool isSpec, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
/// Prepares int ElementsAttr serialization. This method updates `opcode` with
|
||||
@ -184,7 +186,7 @@ private:
|
||||
/// constant to `operands`.
|
||||
LogicalResult prepareIntVectorConstant(Location loc,
|
||||
DenseIntElementsAttr elementsAttr,
|
||||
spirv::Opcode &opcode,
|
||||
bool isSpec, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
/// Prepares float ElementsAttr serialization. This method updates `opcode`
|
||||
@ -192,14 +194,14 @@ private:
|
||||
/// constant to `operands`.
|
||||
LogicalResult prepareFloatVectorConstant(Location loc,
|
||||
DenseFPElementsAttr elementsAttr,
|
||||
spirv::Opcode &opcode,
|
||||
bool isSpec, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr);
|
||||
uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec);
|
||||
|
||||
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr);
|
||||
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec);
|
||||
|
||||
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr);
|
||||
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operations
|
||||
@ -317,7 +319,8 @@ LogicalResult Serializer::processMemoryModel() {
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
|
||||
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
|
||||
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(),
|
||||
op.is_spec_const())) {
|
||||
valueIDMap[op.getResult()] = resultID;
|
||||
return success();
|
||||
}
|
||||
@ -484,7 +487,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
|
||||
}
|
||||
operands.push_back(elementTypeID);
|
||||
if (auto elementCountID = prepareConstantInt(
|
||||
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
|
||||
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
|
||||
/*isSpec=*/false)) {
|
||||
operands.push_back(elementCountID);
|
||||
return success();
|
||||
}
|
||||
@ -535,15 +539,15 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
||||
Attribute valueAttr) {
|
||||
Attribute valueAttr, bool isSpec) {
|
||||
if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
|
||||
return prepareConstantFp(loc, floatAttr);
|
||||
return prepareConstantFp(loc, floatAttr, isSpec);
|
||||
}
|
||||
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
|
||||
return prepareConstantInt(loc, intAttr);
|
||||
return prepareConstantInt(loc, intAttr, isSpec);
|
||||
}
|
||||
if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
|
||||
return prepareConstantBool(loc, boolAttr);
|
||||
return prepareConstantBool(loc, boolAttr, isSpec);
|
||||
}
|
||||
|
||||
// This is a composite literal. We need to handle each component separately
|
||||
@ -566,21 +570,25 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
||||
|
||||
if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (vectorAttr.getType().getElementType().isInteger(1)) {
|
||||
if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode,
|
||||
operands)))
|
||||
return 0;
|
||||
} else if (failed(
|
||||
prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
} else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode,
|
||||
operands)))
|
||||
return 0;
|
||||
} else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode,
|
||||
operands)))
|
||||
return 0;
|
||||
} else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
operands.reserve(arrayAttr.size() + 2);
|
||||
|
||||
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
|
||||
for (Attribute elementAttr : arrayAttr)
|
||||
if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
|
||||
if (auto elementID =
|
||||
prepareConstant(loc, elementType, elementAttr, isSpec)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return 0;
|
||||
@ -596,8 +604,8 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
||||
}
|
||||
|
||||
LogicalResult Serializer::prepareBoolVectorConstant(
|
||||
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
|
||||
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
@ -612,13 +620,15 @@ LogicalResult Serializer::prepareBoolVectorConstant(
|
||||
// the splat value is zero.
|
||||
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
|
||||
// We can use OpConstantNull if this bool ElementsAttr is splatting false.
|
||||
if (!splatAttr.cast<BoolAttr>().getValue()) {
|
||||
if (!isSpec && !splatAttr.cast<BoolAttr>().getValue()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto id = prepareConstantBool(loc, splatAttr.cast<BoolAttr>())) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
if (auto id =
|
||||
prepareConstantBool(loc, splatAttr.cast<BoolAttr>(), isSpec)) {
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
}
|
||||
@ -628,13 +638,14 @@ LogicalResult Serializer::prepareBoolVectorConstant(
|
||||
|
||||
// Otherwise, we need to process each element and compose them with
|
||||
// OpConstantComposite.
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
for (APInt intValue : elementsAttr) {
|
||||
// We are constructing an BoolAttr for each APInt here. But given that
|
||||
// we only use ElementsAttr for vectors with no more than 4 elements, it
|
||||
// should be fine here.
|
||||
auto boolAttr = mlirBuilder.getBoolAttr(intValue.isOneValue());
|
||||
if (auto elementID = prepareConstantBool(loc, boolAttr)) {
|
||||
if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
@ -644,8 +655,8 @@ LogicalResult Serializer::prepareBoolVectorConstant(
|
||||
}
|
||||
|
||||
LogicalResult Serializer::prepareIntVectorConstant(
|
||||
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
|
||||
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
@ -661,13 +672,15 @@ LogicalResult Serializer::prepareIntVectorConstant(
|
||||
// the splat value is zero.
|
||||
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
|
||||
// We can use OpConstantNull if this int ElementsAttr is splatting 0.
|
||||
if (splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
|
||||
if (!isSpec && splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto id = prepareConstantInt(loc, splatAttr.cast<IntegerAttr>())) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
if (auto id =
|
||||
prepareConstantInt(loc, splatAttr.cast<IntegerAttr>(), isSpec)) {
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
}
|
||||
@ -676,7 +689,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
|
||||
|
||||
// Otherwise, we need to process each element and compose them with
|
||||
// OpConstantComposite.
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
for (APInt intValue : elementsAttr) {
|
||||
// We are constructing an IntegerAttr for each APInt here. But given that
|
||||
// we only use ElementsAttr for vectors with no more than 4 elements, it
|
||||
@ -684,7 +698,7 @@ LogicalResult Serializer::prepareIntVectorConstant(
|
||||
// TODO(antiagainst): revisit this if special extensions enabling large
|
||||
// vectors are supported.
|
||||
auto intAttr = mlirBuilder.getIntegerAttr(elementType, intValue);
|
||||
if (auto elementID = prepareConstantInt(loc, intAttr)) {
|
||||
if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
@ -694,8 +708,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
|
||||
}
|
||||
|
||||
LogicalResult Serializer::prepareFloatVectorConstant(
|
||||
Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
Location loc, DenseFPElementsAttr elementsAttr, bool isSpec,
|
||||
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
@ -706,13 +720,14 @@ LogicalResult Serializer::prepareFloatVectorConstant(
|
||||
operands.reserve(count + 2);
|
||||
|
||||
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
|
||||
if (splatAttr.cast<FloatAttr>().getValue().isZero()) {
|
||||
if (!isSpec && splatAttr.cast<FloatAttr>().getValue().isZero()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>())) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>(), isSpec)) {
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
}
|
||||
@ -720,10 +735,11 @@ LogicalResult Serializer::prepareFloatVectorConstant(
|
||||
return failure();
|
||||
}
|
||||
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
for (APFloat floatValue : elementsAttr) {
|
||||
auto fpAttr = mlirBuilder.getFloatAttr(elementType, floatValue);
|
||||
if (auto elementID = prepareConstantFp(loc, fpAttr)) {
|
||||
if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
@ -732,7 +748,8 @@ LogicalResult Serializer::prepareFloatVectorConstant(
|
||||
return success();
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) {
|
||||
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
|
||||
bool isSpec) {
|
||||
if (auto id = findConstantID(boolAttr)) {
|
||||
return id;
|
||||
}
|
||||
@ -744,14 +761,18 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) {
|
||||
}
|
||||
|
||||
auto resultID = getNextID();
|
||||
auto opcode = boolAttr.getValue() ? spirv::Opcode::OpConstantTrue
|
||||
: spirv::Opcode::OpConstantFalse;
|
||||
auto opcode = boolAttr.getValue()
|
||||
? (isSpec ? spirv::Opcode::OpSpecConstantTrue
|
||||
: spirv::Opcode::OpConstantTrue)
|
||||
: (isSpec ? spirv::Opcode::OpSpecConstantFalse
|
||||
: spirv::Opcode::OpConstantFalse);
|
||||
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
|
||||
|
||||
return constIDMap[boolAttr] = resultID;
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
|
||||
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
|
||||
bool isSpec) {
|
||||
if (auto id = findConstantID(intAttr)) {
|
||||
return id;
|
||||
}
|
||||
@ -767,6 +788,9 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
|
||||
unsigned bitwidth = value.getBitWidth();
|
||||
bool isSigned = value.isSignedIntN(bitwidth);
|
||||
|
||||
auto opcode =
|
||||
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
|
||||
|
||||
// According to SPIR-V spec, "When the type's bit width is less than 32-bits,
|
||||
// the literal's value appears in the low-order bits of the word, and the
|
||||
// high-order bits must be 0 for a floating-point type, or 0 for an integer
|
||||
@ -778,8 +802,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
|
||||
} else {
|
||||
word = static_cast<uint32_t>(value.getZExtValue());
|
||||
}
|
||||
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
|
||||
{typeID, resultID, word});
|
||||
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
|
||||
}
|
||||
// According to SPIR-V spec: "When the type's bit width is larger than one
|
||||
// word, the literal’s low-order words appear first."
|
||||
@ -793,7 +816,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
|
||||
} else {
|
||||
words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
|
||||
}
|
||||
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
|
||||
encodeInstructionInto(typesGlobalValues, opcode,
|
||||
{typeID, resultID, words.word1, words.word2});
|
||||
} else {
|
||||
std::string valueStr;
|
||||
@ -808,7 +831,8 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
|
||||
return constIDMap[intAttr] = resultID;
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
|
||||
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
|
||||
bool isSpec) {
|
||||
if (auto id = findConstantID(floatAttr)) {
|
||||
return id;
|
||||
}
|
||||
@ -823,22 +847,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
|
||||
APFloat value = floatAttr.getValue();
|
||||
APInt intValue = value.bitcastToAPInt();
|
||||
|
||||
auto opcode =
|
||||
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
|
||||
|
||||
if (&value.getSemantics() == &APFloat::IEEEsingle()) {
|
||||
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
|
||||
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
|
||||
{typeID, resultID, word});
|
||||
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
|
||||
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
|
||||
struct DoubleWord {
|
||||
uint32_t word1;
|
||||
uint32_t word2;
|
||||
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
|
||||
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
|
||||
encodeInstructionInto(typesGlobalValues, opcode,
|
||||
{typeID, resultID, words.word1, words.word2});
|
||||
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
|
||||
uint32_t word =
|
||||
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
|
||||
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
|
||||
{typeID, resultID, word});
|
||||
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
|
||||
} else {
|
||||
std::string valueStr;
|
||||
llvm::raw_string_ostream rss(valueStr);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user