[DRR] Allow interleaved operands and attributes

Previously DRR assumes attributes to appear after operands. This was the
previous requirements on ODS, but that has changed some time ago. Fix
DRR to also support interleaved operands and attributes.

PiperOrigin-RevId: 275983485
Change-Id: I8ba42e442839e3a03d6dbc3f06d1d70a3e06e6fa
This commit is contained in:
Lei Zhang 2019-10-21 20:47:49 -07:00 committed by TensorFlower Gardener
parent 09606bba45
commit 5f0ed20652
4 changed files with 94 additions and 39 deletions
third_party/mlir
include/mlir/TableGen
lib/TableGen
test/lib/TestDialect
tools/mlir-tblgen

View File

@ -124,6 +124,14 @@ public:
// Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); }
using arg_iterator = const Argument *;
using arg_range = llvm::iterator_range<arg_iterator>;
// Op argument (attribute or operand) iterators.
arg_iterator arg_begin() const;
arg_iterator arg_end() const;
arg_range getArgs() const;
// Op argument (attribute or operand) accessors.
Argument getArg(int index) const;
StringRef getArgName(int index) const;

View File

@ -126,6 +126,18 @@ unsigned tblgen::Operator::getNumVariadicOperands() const {
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
return arguments.begin();
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const {
return arguments.end();
}
tblgen::Operator::arg_range tblgen::Operator::getArgs() const {
return {arg_begin(), arg_end()};
}
StringRef tblgen::Operator::getArgName(int index) const {
DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgName(index)->getValue();

View File

@ -426,6 +426,28 @@ def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
def : Pat<(OpJ), (OpK)>;
def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
let arguments = (ins
I32:$input1,
I64Attr:$attr1,
I32:$input2,
I64Attr:$attr2
);
}
def OpInterleavedOperandAttribute2 : TEST_Op<"interleaved_operand_attr2"> {
let arguments = (ins
I32:$input1,
I64Attr:$attr1,
I32:$input2,
I64Attr:$attr2
);
}
// Test that we can capture and reference interleaved operands and attributes.
def : Pat<(OpInterleavedOperandAttribute1 $input1, $attr1, $input2, $attr2),
(OpInterleavedOperandAttribute2 $input1, $attr1, $input2, $attr2)>;
// Test NativeCodeCall.
def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> {
let arguments = (ins

View File

@ -81,13 +81,13 @@ private:
// `tree`.
void emitOpMatch(DagNode tree, int depth);
// Emits C++ statements for matching the `index`-th argument of the given DAG
// `tree` as an operand.
void emitOperandMatch(DagNode tree, int index, int depth, int indent);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an operand.
void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
// Emits C++ statements for matching the `index`-th argument of the given DAG
// `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
//===--------------------------------------------------------------------===//
// Rewrite utilities
@ -260,11 +260,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
<< '\n');
}
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
int indent) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
auto matcher = tree.getArgAsLeaf(index);
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
auto matcher = tree.getArgAsLeaf(argIndex);
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
@ -272,7 +272,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
if (!matcher.isOperandMatcher()) {
PrintFatalError(
loc, formatv("the {1}-th argument of op '{0}' should be an operand",
op.getOperationName(), index + 1));
op.getOperationName(), argIndex + 1));
}
// Only need to verify if the matcher's type is different from the one
@ -281,12 +281,12 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
if (operand->isVariadic()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), index);
op.getOperationName(), argIndex);
PrintFatalError(loc, error);
}
auto self =
formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
depth, index);
depth, argIndex);
os.indent(indent) << "if (!("
<< tgfmt(matcher.getConditionTemplate(),
&fmtCtx.withSelf(self))
@ -295,17 +295,23 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
}
// Capture the value
auto name = tree.getArgName(index);
auto name = tree.getArgName(argIndex);
if (!name.empty()) {
// We need to subtract the number of attributes before this operand to get
// the index in the operand list.
auto numPrevAttrs = std::count_if(
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
name, depth, index);
name, depth, argIndex - numPrevAttrs);
}
}
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
int indent) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
const auto &attr = namedAttr->attr;
os.indent(indent) << "{\n";
@ -328,12 +334,12 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
}
auto matcher = tree.getArgAsLeaf(index);
auto matcher = tree.getArgAsLeaf(argIndex);
if (!matcher.isUnspecified()) {
if (!matcher.isAttrMatcher()) {
PrintFatalError(
loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
op.getOperationName(), index + 1));
op.getOperationName(), argIndex + 1));
}
// If a constraint is specified, we need to generate C++ statements to
@ -345,7 +351,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
}
// Capture the value
auto name = tree.getArgName(index);
auto name = tree.getArgName(argIndex);
if (!name.empty()) {
os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
}
@ -683,6 +689,10 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
int depth) {
LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs = resultOp.getNumArgs();
@ -734,12 +744,16 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// * If the operand is variadic, we create a `SmallVector<Value*>` local
// variable.
int argIndex = 0; // The current index to this op's ODS argument
int valueIndex = 0; // An index for uniquing local variable names.
for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
const auto &operand = resultOp.getOperand(argIndex);
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
const auto *operand =
resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
if (!operand) {
// We do not need special handling for attributes.
continue;
}
std::string varName;
if (operand.isVariadic()) {
if (operand->isVariadic()) {
varName = formatv("tblgen_values_{0}", valueIndex++);
os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
std::string range;
@ -814,22 +828,22 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
os.indent(6) << ", tblgen_types";
}
// Add operands for the builder all.
for (int i = 0; i < argIndex; ++i) {
const auto &operand = resultOp.getOperand(i);
// Start each operand on its own line.
// Add arguments for the builder call.
for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) {
// Start each argment on its own line.
(os << ",\n").indent(8);
if (!operand.name.empty()) {
os << "/*" << operand.name << "=*/";
}
os << childNodeNames[i];
// TODO(jpienaar): verify types
}
// Add attributes for the builder call.
for (; argIndex != numOpArgs; ++argIndex) {
// Start each attribute on its own line.
(os << ",\n").indent(8);
Argument opArg = resultOp.getArg(argIndex);
// Handle the case of operand first.
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
if (!operand->name.empty()) {
os << "/*" << operand->name << "=*/";
}
os << childNodeNames[argIndex];
// TODO(jpienaar): verify types
continue;
}
// The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex);
if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
@ -844,8 +858,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
auto patArgName = tree.getArgName(argIndex);
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(argIndex);
if (!argument.is<NamedAttribute *>())
if (!opArg.is<NamedAttribute *>())
PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";