[DRR] Allow interleaved operands and attributes

Previously DRR assumes attributes to appear after operands. This was the
previous requirements on ODS, but that has changed some time ago. Fix
DRR to also support interleaved operands and attributes.

PiperOrigin-RevId: 275983485
Change-Id: I8ba42e442839e3a03d6dbc3f06d1d70a3e06e6fa
This commit is contained in:
Lei Zhang 2019-10-21 20:47:49 -07:00 committed by TensorFlower Gardener
parent 09606bba45
commit 5f0ed20652
4 changed files with 94 additions and 39 deletions

View File

@ -124,6 +124,14 @@ public:
// Returns the total number of arguments. // Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); } int getNumArgs() const { return arguments.size(); }
using arg_iterator = const Argument *;
using arg_range = llvm::iterator_range<arg_iterator>;
// Op argument (attribute or operand) iterators.
arg_iterator arg_begin() const;
arg_iterator arg_end() const;
arg_range getArgs() const;
// Op argument (attribute or operand) accessors. // Op argument (attribute or operand) accessors.
Argument getArg(int index) const; Argument getArg(int index) const;
StringRef getArgName(int index) const; StringRef getArgName(int index) const;

View File

@ -126,6 +126,18 @@ unsigned tblgen::Operator::getNumVariadicOperands() const {
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
} }
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
return arguments.begin();
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const {
return arguments.end();
}
tblgen::Operator::arg_range tblgen::Operator::getArgs() const {
return {arg_begin(), arg_end()};
}
StringRef tblgen::Operator::getArgName(int index) const { StringRef tblgen::Operator::getArgName(int index) const {
DagInit *argumentValues = def.getValueAsDag("arguments"); DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgName(index)->getValue(); return argumentValues->getArgName(index)->getValue();

View File

@ -426,6 +426,28 @@ def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>; def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
def : Pat<(OpJ), (OpK)>; def : Pat<(OpJ), (OpK)>;
def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
let arguments = (ins
I32:$input1,
I64Attr:$attr1,
I32:$input2,
I64Attr:$attr2
);
}
def OpInterleavedOperandAttribute2 : TEST_Op<"interleaved_operand_attr2"> {
let arguments = (ins
I32:$input1,
I64Attr:$attr1,
I32:$input2,
I64Attr:$attr2
);
}
// Test that we can capture and reference interleaved operands and attributes.
def : Pat<(OpInterleavedOperandAttribute1 $input1, $attr1, $input2, $attr2),
(OpInterleavedOperandAttribute2 $input1, $attr1, $input2, $attr2)>;
// Test NativeCodeCall. // Test NativeCodeCall.
def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> { def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> {
let arguments = (ins let arguments = (ins

View File

@ -81,13 +81,13 @@ private:
// `tree`. // `tree`.
void emitOpMatch(DagNode tree, int depth); void emitOpMatch(DagNode tree, int depth);
// Emits C++ statements for matching the `index`-th argument of the given DAG // Emits C++ statements for matching the `argIndex`-th argument of the given
// `tree` as an operand. // DAG `tree` as an operand.
void emitOperandMatch(DagNode tree, int index, int depth, int indent); void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
// Emits C++ statements for matching the `index`-th argument of the given DAG // Emits C++ statements for matching the `argIndex`-th argument of the given
// `tree` as an attribute. // DAG `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int index, int depth, int indent); void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Rewrite utilities // Rewrite utilities
@ -260,11 +260,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
<< '\n'); << '\n');
} }
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
int indent) { int indent) {
Operator &op = tree.getDialectOp(opMap); Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(index).get<NamedTypeConstraint *>(); auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
auto matcher = tree.getArgAsLeaf(index); auto matcher = tree.getArgAsLeaf(argIndex);
// If a constraint is specified, we need to generate C++ statements to // If a constraint is specified, we need to generate C++ statements to
// check the constraint. // check the constraint.
@ -272,7 +272,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
if (!matcher.isOperandMatcher()) { if (!matcher.isOperandMatcher()) {
PrintFatalError( PrintFatalError(
loc, formatv("the {1}-th argument of op '{0}' should be an operand", loc, formatv("the {1}-th argument of op '{0}' should be an operand",
op.getOperationName(), index + 1)); op.getOperationName(), argIndex + 1));
} }
// Only need to verify if the matcher's type is different from the one // Only need to verify if the matcher's type is different from the one
@ -281,12 +281,12 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
if (operand->isVariadic()) { if (operand->isVariadic()) {
auto error = formatv( auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now", "further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), index); op.getOperationName(), argIndex);
PrintFatalError(loc, error); PrintFatalError(loc, error);
} }
auto self = auto self =
formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()", formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
depth, index); depth, argIndex);
os.indent(indent) << "if (!(" os.indent(indent) << "if (!("
<< tgfmt(matcher.getConditionTemplate(), << tgfmt(matcher.getConditionTemplate(),
&fmtCtx.withSelf(self)) &fmtCtx.withSelf(self))
@ -295,17 +295,23 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
} }
// Capture the value // Capture the value
auto name = tree.getArgName(index); auto name = tree.getArgName(argIndex);
if (!name.empty()) { if (!name.empty()) {
// We need to subtract the number of attributes before this operand to get
// the index in the operand list.
auto numPrevAttrs = std::count_if(
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n", os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
name, depth, index); name, depth, argIndex - numPrevAttrs);
} }
} }
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
int indent) { int indent) {
Operator &op = tree.getDialectOp(opMap); Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
const auto &attr = namedAttr->attr; const auto &attr = namedAttr->attr;
os.indent(indent) << "{\n"; os.indent(indent) << "{\n";
@ -328,12 +334,12 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
} }
auto matcher = tree.getArgAsLeaf(index); auto matcher = tree.getArgAsLeaf(argIndex);
if (!matcher.isUnspecified()) { if (!matcher.isUnspecified()) {
if (!matcher.isAttrMatcher()) { if (!matcher.isAttrMatcher()) {
PrintFatalError( PrintFatalError(
loc, formatv("the {1}-th argument of op '{0}' should be an attribute", loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
op.getOperationName(), index + 1)); op.getOperationName(), argIndex + 1));
} }
// If a constraint is specified, we need to generate C++ statements to // If a constraint is specified, we need to generate C++ statements to
@ -345,7 +351,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
} }
// Capture the value // Capture the value
auto name = tree.getArgName(index); auto name = tree.getArgName(argIndex);
if (!name.empty()) { if (!name.empty()) {
os.indent(indent) << formatv("{0} = tblgen_attr;\n", name); os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
} }
@ -683,6 +689,10 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
int depth) { int depth) {
LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
Operator &resultOp = tree.getDialectOp(opMap); Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs = resultOp.getNumArgs(); auto numOpArgs = resultOp.getNumArgs();
@ -734,12 +744,16 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// * If the operand is variadic, we create a `SmallVector<Value*>` local // * If the operand is variadic, we create a `SmallVector<Value*>` local
// variable. // variable.
int argIndex = 0; // The current index to this op's ODS argument
int valueIndex = 0; // An index for uniquing local variable names. int valueIndex = 0; // An index for uniquing local variable names.
for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) { for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
const auto &operand = resultOp.getOperand(argIndex); const auto *operand =
resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
if (!operand) {
// We do not need special handling for attributes.
continue;
}
std::string varName; std::string varName;
if (operand.isVariadic()) { if (operand->isVariadic()) {
varName = formatv("tblgen_values_{0}", valueIndex++); varName = formatv("tblgen_values_{0}", valueIndex++);
os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName); os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
std::string range; std::string range;
@ -814,22 +828,22 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
os.indent(6) << ", tblgen_types"; os.indent(6) << ", tblgen_types";
} }
// Add operands for the builder all. // Add arguments for the builder call.
for (int i = 0; i < argIndex; ++i) { for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) {
const auto &operand = resultOp.getOperand(i); // Start each argment on its own line.
// Start each operand on its own line.
(os << ",\n").indent(8); (os << ",\n").indent(8);
if (!operand.name.empty()) {
os << "/*" << operand.name << "=*/";
}
os << childNodeNames[i];
// TODO(jpienaar): verify types
}
// Add attributes for the builder call. Argument opArg = resultOp.getArg(argIndex);
for (; argIndex != numOpArgs; ++argIndex) { // Handle the case of operand first.
// Start each attribute on its own line. if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
(os << ",\n").indent(8); if (!operand->name.empty()) {
os << "/*" << operand->name << "=*/";
}
os << childNodeNames[argIndex];
// TODO(jpienaar): verify types
continue;
}
// The argument in the op definition. // The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex); auto opArgName = resultOp.getArgName(argIndex);
if (auto subTree = tree.getArgAsNestedDag(argIndex)) { if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
@ -844,8 +858,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
auto patArgName = tree.getArgName(argIndex); auto patArgName = tree.getArgName(argIndex);
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these. // TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(argIndex); if (!opArg.is<NamedAttribute *>())
if (!argument.is<NamedAttribute *>())
PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
if (!patArgName.empty()) if (!patArgName.empty())
os << "/*" << patArgName << "=*/"; os << "/*" << patArgName << "=*/";