[DRR] Allow interleaved operands and attributes
Previously DRR assumes attributes to appear after operands. This was the previous requirements on ODS, but that has changed some time ago. Fix DRR to also support interleaved operands and attributes. PiperOrigin-RevId: 275983485 Change-Id: I8ba42e442839e3a03d6dbc3f06d1d70a3e06e6fa
This commit is contained in:
parent
09606bba45
commit
5f0ed20652
third_party/mlir
include/mlir/TableGen
lib/TableGen
test/lib/TestDialect
tools/mlir-tblgen
@ -124,6 +124,14 @@ public:
|
||||
// Returns the total number of arguments.
|
||||
int getNumArgs() const { return arguments.size(); }
|
||||
|
||||
using arg_iterator = const Argument *;
|
||||
using arg_range = llvm::iterator_range<arg_iterator>;
|
||||
|
||||
// Op argument (attribute or operand) iterators.
|
||||
arg_iterator arg_begin() const;
|
||||
arg_iterator arg_end() const;
|
||||
arg_range getArgs() const;
|
||||
|
||||
// Op argument (attribute or operand) accessors.
|
||||
Argument getArg(int index) const;
|
||||
StringRef getArgName(int index) const;
|
||||
|
12
third_party/mlir/lib/TableGen/Operator.cpp
vendored
12
third_party/mlir/lib/TableGen/Operator.cpp
vendored
@ -126,6 +126,18 @@ unsigned tblgen::Operator::getNumVariadicOperands() const {
|
||||
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
|
||||
}
|
||||
|
||||
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
|
||||
return arguments.begin();
|
||||
}
|
||||
|
||||
tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const {
|
||||
return arguments.end();
|
||||
}
|
||||
|
||||
tblgen::Operator::arg_range tblgen::Operator::getArgs() const {
|
||||
return {arg_begin(), arg_end()};
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::getArgName(int index) const {
|
||||
DagInit *argumentValues = def.getValueAsDag("arguments");
|
||||
return argumentValues->getArgName(index)->getValue();
|
||||
|
22
third_party/mlir/test/lib/TestDialect/TestOps.td
vendored
22
third_party/mlir/test/lib/TestDialect/TestOps.td
vendored
@ -426,6 +426,28 @@ def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
|
||||
def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
|
||||
def : Pat<(OpJ), (OpK)>;
|
||||
|
||||
def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
|
||||
let arguments = (ins
|
||||
I32:$input1,
|
||||
I64Attr:$attr1,
|
||||
I32:$input2,
|
||||
I64Attr:$attr2
|
||||
);
|
||||
}
|
||||
|
||||
def OpInterleavedOperandAttribute2 : TEST_Op<"interleaved_operand_attr2"> {
|
||||
let arguments = (ins
|
||||
I32:$input1,
|
||||
I64Attr:$attr1,
|
||||
I32:$input2,
|
||||
I64Attr:$attr2
|
||||
);
|
||||
}
|
||||
|
||||
// Test that we can capture and reference interleaved operands and attributes.
|
||||
def : Pat<(OpInterleavedOperandAttribute1 $input1, $attr1, $input2, $attr2),
|
||||
(OpInterleavedOperandAttribute2 $input1, $attr1, $input2, $attr2)>;
|
||||
|
||||
// Test NativeCodeCall.
|
||||
def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> {
|
||||
let arguments = (ins
|
||||
|
@ -81,13 +81,13 @@ private:
|
||||
// `tree`.
|
||||
void emitOpMatch(DagNode tree, int depth);
|
||||
|
||||
// Emits C++ statements for matching the `index`-th argument of the given DAG
|
||||
// `tree` as an operand.
|
||||
void emitOperandMatch(DagNode tree, int index, int depth, int indent);
|
||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||
// DAG `tree` as an operand.
|
||||
void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
|
||||
|
||||
// Emits C++ statements for matching the `index`-th argument of the given DAG
|
||||
// `tree` as an attribute.
|
||||
void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
|
||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||
// DAG `tree` as an attribute.
|
||||
void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Rewrite utilities
|
||||
@ -260,11 +260,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
||||
<< '\n');
|
||||
}
|
||||
|
||||
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
||||
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
|
||||
int indent) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
|
||||
auto matcher = tree.getArgAsLeaf(index);
|
||||
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
|
||||
auto matcher = tree.getArgAsLeaf(argIndex);
|
||||
|
||||
// If a constraint is specified, we need to generate C++ statements to
|
||||
// check the constraint.
|
||||
@ -272,7 +272,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
||||
if (!matcher.isOperandMatcher()) {
|
||||
PrintFatalError(
|
||||
loc, formatv("the {1}-th argument of op '{0}' should be an operand",
|
||||
op.getOperationName(), index + 1));
|
||||
op.getOperationName(), argIndex + 1));
|
||||
}
|
||||
|
||||
// Only need to verify if the matcher's type is different from the one
|
||||
@ -281,12 +281,12 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
||||
if (operand->isVariadic()) {
|
||||
auto error = formatv(
|
||||
"further constrain op {0}'s variadic operand #{1} unsupported now",
|
||||
op.getOperationName(), index);
|
||||
op.getOperationName(), argIndex);
|
||||
PrintFatalError(loc, error);
|
||||
}
|
||||
auto self =
|
||||
formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
|
||||
depth, index);
|
||||
depth, argIndex);
|
||||
os.indent(indent) << "if (!("
|
||||
<< tgfmt(matcher.getConditionTemplate(),
|
||||
&fmtCtx.withSelf(self))
|
||||
@ -295,17 +295,23 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
||||
}
|
||||
|
||||
// Capture the value
|
||||
auto name = tree.getArgName(index);
|
||||
auto name = tree.getArgName(argIndex);
|
||||
if (!name.empty()) {
|
||||
// We need to subtract the number of attributes before this operand to get
|
||||
// the index in the operand list.
|
||||
auto numPrevAttrs = std::count_if(
|
||||
op.arg_begin(), op.arg_begin() + argIndex,
|
||||
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
|
||||
|
||||
os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
|
||||
name, depth, index);
|
||||
name, depth, argIndex - numPrevAttrs);
|
||||
}
|
||||
}
|
||||
|
||||
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
|
||||
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
|
||||
int indent) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
|
||||
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
|
||||
const auto &attr = namedAttr->attr;
|
||||
|
||||
os.indent(indent) << "{\n";
|
||||
@ -328,12 +334,12 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
|
||||
os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
|
||||
}
|
||||
|
||||
auto matcher = tree.getArgAsLeaf(index);
|
||||
auto matcher = tree.getArgAsLeaf(argIndex);
|
||||
if (!matcher.isUnspecified()) {
|
||||
if (!matcher.isAttrMatcher()) {
|
||||
PrintFatalError(
|
||||
loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
|
||||
op.getOperationName(), index + 1));
|
||||
op.getOperationName(), argIndex + 1));
|
||||
}
|
||||
|
||||
// If a constraint is specified, we need to generate C++ statements to
|
||||
@ -345,7 +351,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
|
||||
}
|
||||
|
||||
// Capture the value
|
||||
auto name = tree.getArgName(index);
|
||||
auto name = tree.getArgName(argIndex);
|
||||
if (!name.empty()) {
|
||||
os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
|
||||
}
|
||||
@ -683,6 +689,10 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
|
||||
|
||||
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
||||
int depth) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
|
||||
LLVM_DEBUG(tree.print(llvm::dbgs()));
|
||||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||
|
||||
Operator &resultOp = tree.getDialectOp(opMap);
|
||||
auto numOpArgs = resultOp.getNumArgs();
|
||||
|
||||
@ -734,12 +744,16 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
||||
// * If the operand is variadic, we create a `SmallVector<Value*>` local
|
||||
// variable.
|
||||
|
||||
int argIndex = 0; // The current index to this op's ODS argument
|
||||
int valueIndex = 0; // An index for uniquing local variable names.
|
||||
for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
|
||||
const auto &operand = resultOp.getOperand(argIndex);
|
||||
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
|
||||
const auto *operand =
|
||||
resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
|
||||
if (!operand) {
|
||||
// We do not need special handling for attributes.
|
||||
continue;
|
||||
}
|
||||
std::string varName;
|
||||
if (operand.isVariadic()) {
|
||||
if (operand->isVariadic()) {
|
||||
varName = formatv("tblgen_values_{0}", valueIndex++);
|
||||
os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
|
||||
std::string range;
|
||||
@ -814,22 +828,22 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
||||
os.indent(6) << ", tblgen_types";
|
||||
}
|
||||
|
||||
// Add operands for the builder all.
|
||||
for (int i = 0; i < argIndex; ++i) {
|
||||
const auto &operand = resultOp.getOperand(i);
|
||||
// Start each operand on its own line.
|
||||
// Add arguments for the builder call.
|
||||
for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) {
|
||||
// Start each argment on its own line.
|
||||
(os << ",\n").indent(8);
|
||||
if (!operand.name.empty()) {
|
||||
os << "/*" << operand.name << "=*/";
|
||||
}
|
||||
os << childNodeNames[i];
|
||||
// TODO(jpienaar): verify types
|
||||
}
|
||||
|
||||
// Add attributes for the builder call.
|
||||
for (; argIndex != numOpArgs; ++argIndex) {
|
||||
// Start each attribute on its own line.
|
||||
(os << ",\n").indent(8);
|
||||
Argument opArg = resultOp.getArg(argIndex);
|
||||
// Handle the case of operand first.
|
||||
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
|
||||
if (!operand->name.empty()) {
|
||||
os << "/*" << operand->name << "=*/";
|
||||
}
|
||||
os << childNodeNames[argIndex];
|
||||
// TODO(jpienaar): verify types
|
||||
continue;
|
||||
}
|
||||
|
||||
// The argument in the op definition.
|
||||
auto opArgName = resultOp.getArgName(argIndex);
|
||||
if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
|
||||
@ -844,8 +858,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
||||
auto patArgName = tree.getArgName(argIndex);
|
||||
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(argIndex);
|
||||
if (!argument.is<NamedAttribute *>())
|
||||
if (!opArg.is<NamedAttribute *>())
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
|
||||
if (!patArgName.empty())
|
||||
os << "/*" << patArgName << "=*/";
|
||||
|
Loading…
Reference in New Issue
Block a user