Add type inference variant for separate params builder generated
Add variant that does invoke infer type op interface where defined. Also add entry function that invokes that different separate argument builders for wrapped, unwrapped and inference variant. PiperOrigin-RevId: 285220709 Change-Id: I5de5702feece344bc7edfeb2147691a02ce1865f
This commit is contained in:
parent
2f28e151b3
commit
d45163cfe0
18
third_party/mlir/g3doc/OpDefinitions.md
vendored
18
third_party/mlir/g3doc/OpDefinitions.md
vendored
@ -290,7 +290,7 @@ class. See [Constraints](#constraints) for more information.
|
||||
### Operation interfaces
|
||||
|
||||
[Operation interfaces](Interfaces.md#operation-interfaces) are a mechanism by
|
||||
which to opaquely call methods and access information on an *Op instance,
|
||||
which to opaquely call methods and access information on an *Op instance*,
|
||||
without knowing the exact operation type. Operation interfaces defined in C++
|
||||
can be accessed in the ODS framework via the `OpInterfaceTrait` class. Aside
|
||||
from using pre-existing interfaces in the C++ API, the ODS framework also
|
||||
@ -414,7 +414,7 @@ The following builders are generated:
|
||||
// All result-types/operands/attributes have one aggregate parameter.
|
||||
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<Value> operands,
|
||||
ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes);
|
||||
|
||||
// Each result-type/operand/attribute has a separate parameter. The parameters
|
||||
@ -433,7 +433,19 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state,
|
||||
Value *i32_operand, Value *f32_operand, ...,
|
||||
APInt i32_attr, StringRef f32_attr, ...);
|
||||
|
||||
// (And potentially others depending on the specific op.)
|
||||
// Each operand/attribute has a separate parameter but result type is aggregate.
|
||||
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
|
||||
ArrayRef<Type> resultTypes,
|
||||
Value *i32_operand, Value *f32_operand, ...,
|
||||
IntegerAttr i32_attr, FloatAttr f32_attr, ...);
|
||||
|
||||
// All operands/attributes have aggregate parameters.
|
||||
// Generated if InferTypeOpInterface interface is specified.
|
||||
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
|
||||
ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes);
|
||||
|
||||
// (And manually specified builders depending on the specific op.)
|
||||
```
|
||||
|
||||
The first form provides basic uniformity so that we can create ops using the
|
||||
|
@ -514,21 +514,9 @@ private:
|
||||
// Generates builder methods for the operation.
|
||||
void genBuilder();
|
||||
|
||||
// Generates the build() method that takes each result-type/operand/attribute
|
||||
// as a stand-alone parameter. Attributes will take wrapped mlir::Attribute
|
||||
// values. The generated build() method also requires specifying result types
|
||||
// for all results.
|
||||
void genSeparateParamWrappedAttrBuilder();
|
||||
|
||||
// Generates the build() method that takes each result-type/operand/attribute
|
||||
// as a stand-alone parameter. Attributes will take raw values without
|
||||
// mlir::Attribute wrapper. The generated build() method also requires
|
||||
// specifying result types for all results.
|
||||
void genSeparateParamUnwrappedAttrBuilder();
|
||||
|
||||
// Generates the build() method that takes a single parameter for all the
|
||||
// result types and a separate parameter for each operand/attribute.
|
||||
void genCollectiveTypeParamBuilder();
|
||||
// Generates the build() method that takes each operand/attribute
|
||||
// as a stand-alone parameter.
|
||||
void genSeparateArgParamBuilder();
|
||||
|
||||
// Generates the build() method that takes each operand/attribute as a
|
||||
// stand-alone parameter. The generated build() method uses first operand's
|
||||
@ -897,26 +885,11 @@ void OpEmitter::genNamedRegionGetters() {
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genSeparateParamWrappedAttrBuilder() {
|
||||
std::string paramList;
|
||||
llvm::SmallVector<std::string, 4> resultNames;
|
||||
buildParamList(paramList, resultNames, TypeParamKind::Separate);
|
||||
|
||||
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
|
||||
genCodeForAddingArgAndRegionForBuilder(m.body());
|
||||
|
||||
// Push all result types to the operation state
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
m.body() << " " << builderOpState << ".addTypes(" << resultNames[i]
|
||||
<< ");\n";
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genSeparateParamUnwrappedAttrBuilder() {
|
||||
static bool canGenerateUnwrappedBuilder(Operator &op) {
|
||||
// If this op does not have native attributes at all, return directly to avoid
|
||||
// redefining builders.
|
||||
if (op.getNumNativeAttributes() == 0)
|
||||
return;
|
||||
return false;
|
||||
|
||||
bool canGenerate = false;
|
||||
// We are generating builders that take raw values for attributes. We need to
|
||||
@ -930,47 +903,75 @@ void OpEmitter::genSeparateParamUnwrappedAttrBuilder() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!canGenerate)
|
||||
return;
|
||||
|
||||
std::string paramList;
|
||||
llvm::SmallVector<std::string, 4> resultNames;
|
||||
buildParamList(paramList, resultNames, TypeParamKind::Separate,
|
||||
AttrParamKind::UnwrappedValue);
|
||||
|
||||
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
|
||||
genCodeForAddingArgAndRegionForBuilder(m.body(), /*isRawValueAttr=*/true);
|
||||
|
||||
// Push all result types to the operation state.
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
m.body() << " " << builderOpState << ".addTypes(" << resultNames[i]
|
||||
<< ");\n";
|
||||
}
|
||||
return canGenerate;
|
||||
}
|
||||
|
||||
void OpEmitter::genCollectiveTypeParamBuilder() {
|
||||
auto numResults = op.getNumResults();
|
||||
void OpEmitter::genSeparateArgParamBuilder() {
|
||||
SmallVector<AttrParamKind, 2> attrBuilderType;
|
||||
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
|
||||
if (canGenerateUnwrappedBuilder(op))
|
||||
attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
|
||||
|
||||
// If this op has no results, then just skip generating this builder.
|
||||
// Otherwise we are generating the same signature as the separate-parameter
|
||||
// builder.
|
||||
if (numResults == 0)
|
||||
return;
|
||||
// Emit with separate builders with or without unwrapped attributes and/or
|
||||
// inferring result type.
|
||||
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
|
||||
bool inferType) {
|
||||
std::string paramList;
|
||||
llvm::SmallVector<std::string, 4> resultNames;
|
||||
buildParamList(paramList, resultNames, paramKind, attrType);
|
||||
|
||||
// Similarly for ops with one single variadic result, which will also have one
|
||||
// `ArrayRef<Type>` parameter for the result type.
|
||||
if (numResults == 1 && op.getResult(0).isVariadic())
|
||||
return;
|
||||
auto &m =
|
||||
opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
|
||||
auto &body = m.body();
|
||||
genCodeForAddingArgAndRegionForBuilder(
|
||||
body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
|
||||
|
||||
std::string paramList;
|
||||
llvm::SmallVector<std::string, 4> resultNames;
|
||||
buildParamList(paramList, resultNames, TypeParamKind::Collective);
|
||||
// Push all result types to the operation state
|
||||
|
||||
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
|
||||
genCodeForAddingArgAndRegionForBuilder(m.body());
|
||||
if (inferType) {
|
||||
// Generate builder that infers type too.
|
||||
// TODO(jpienaar): Subsume this with general checking if type can be
|
||||
// infered automatically.
|
||||
// TODO(jpienaar): Expand to handle regions.
|
||||
body << formatv(R"(
|
||||
SmallVector<Type, 2> inferedReturnTypes;
|
||||
if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands,
|
||||
{1}.attributes, /*regions=*/{{}, inferedReturnTypes)))
|
||||
{1}.addTypes(inferedReturnTypes);
|
||||
else
|
||||
llvm::report_fatal_error("Failed to infer result type(s).");)",
|
||||
opClass.getClassName(), builderOpState);
|
||||
return;
|
||||
}
|
||||
|
||||
// Push all result types to the operation state
|
||||
m.body() << formatv(" {0}.addTypes(resultTypes);\n", builderOpState);
|
||||
switch (paramKind) {
|
||||
case TypeParamKind::None:
|
||||
return;
|
||||
case TypeParamKind::Separate:
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
body << " " << builderOpState << ".addTypes(" << resultNames[i]
|
||||
<< ");\n";
|
||||
}
|
||||
return;
|
||||
case TypeParamKind::Collective:
|
||||
body << " " << builderOpState << ".addTypes(resultTypes);\n";
|
||||
return;
|
||||
};
|
||||
llvm_unreachable("unhandled TypeParamKind");
|
||||
};
|
||||
|
||||
bool canInferType =
|
||||
op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0;
|
||||
for (auto attrType : attrBuilderType) {
|
||||
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
|
||||
if (canInferType)
|
||||
emit(attrType, TypeParamKind::None, /*inferType=*/true);
|
||||
// Emit separate arg build with collective type, unless there is only one
|
||||
// variadic result, in which case the above would have already generated
|
||||
// the same build method.
|
||||
if (op.getNumResults() == 1 && !op.getResult(0).isVariadic())
|
||||
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
|
||||
@ -1021,8 +1022,7 @@ void OpEmitter::genInferedTypeCollectiveParamBuilder() {
|
||||
/*regions=*/{{}, inferedReturnTypes)))
|
||||
build(builder, tblgen_state, inferedReturnTypes, operands, attributes);
|
||||
else
|
||||
llvm::report_fatal_error("Failed to infer result type(s).");
|
||||
)",
|
||||
llvm::report_fatal_error("Failed to infer result type(s).");)",
|
||||
opClass.getClassName(), builderOpState);
|
||||
}
|
||||
|
||||
@ -1111,18 +1111,13 @@ void OpEmitter::genBuilder() {
|
||||
// Generate default builders that requires all result type, operands, and
|
||||
// attributes as parameters.
|
||||
|
||||
// We generate three builders here:
|
||||
// 1. one having a stand-alone parameter for each result type / operand /
|
||||
// attribute, and
|
||||
genSeparateParamWrappedAttrBuilder();
|
||||
genSeparateParamUnwrappedAttrBuilder();
|
||||
// 2. one having a stand-alone parameter for each operand / attribute and
|
||||
// an aggregated parameter for all result types, and
|
||||
genCollectiveTypeParamBuilder();
|
||||
// 3. one having an aggregated parameter for all result types / operands /
|
||||
// We generate three classes of builders here:
|
||||
// 1. one having a stand-alone parameter for each operand / attribute, and
|
||||
genSeparateArgParamBuilder();
|
||||
// 2. one having an aggregated parameter for all result types / operands /
|
||||
// attributes, and
|
||||
genCollectiveParamBuilder();
|
||||
// 4. one having a stand-alone parameter for each operand and attribute,
|
||||
// 3. one having a stand-alone parameter for each operand and attribute,
|
||||
// use the first operand or attribute's type as all result types
|
||||
// to facilitate different call patterns.
|
||||
if (op.getNumVariadicResults() == 0) {
|
||||
@ -1133,11 +1128,6 @@ void OpEmitter::genBuilder() {
|
||||
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
|
||||
genUseAttrAsResultTypeBuilder();
|
||||
}
|
||||
// TODO(jpienaar): Subsume this with general checking if type can be infered
|
||||
// automatically.
|
||||
// TODO(jpienaar): Expand to handle regions.
|
||||
if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
|
||||
genInferedTypeCollectiveParamBuilder();
|
||||
}
|
||||
|
||||
void OpEmitter::genCollectiveParamBuilder() {
|
||||
@ -1156,13 +1146,6 @@ void OpEmitter::genCollectiveParamBuilder() {
|
||||
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
|
||||
auto &body = m.body();
|
||||
|
||||
// Result types
|
||||
if (numVariadicResults == 0 || numNonVariadicResults != 0)
|
||||
body << " assert(resultTypes.size()"
|
||||
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
|
||||
<< "u && \"mismatched number of return types\");\n";
|
||||
body << " " << builderOpState << ".addTypes(resultTypes);\n";
|
||||
|
||||
// Operands
|
||||
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
|
||||
body << " assert(operands.size()"
|
||||
@ -1179,6 +1162,20 @@ void OpEmitter::genCollectiveParamBuilder() {
|
||||
for (int i = 0; i < numRegions; ++i)
|
||||
m.body() << " (void)" << builderOpState << ".addRegion();\n";
|
||||
}
|
||||
|
||||
// Result types
|
||||
if (numVariadicResults == 0 || numNonVariadicResults != 0)
|
||||
body << " assert(resultTypes.size()"
|
||||
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
|
||||
<< "u && \"mismatched number of return types\");\n";
|
||||
body << " " << builderOpState << ".addTypes(resultTypes);\n";
|
||||
|
||||
// Generate builder that infers type too.
|
||||
// TODO(jpienaar): Subsume this with general checking if type can be infered
|
||||
// automatically.
|
||||
// TODO(jpienaar): Expand to handle regions.
|
||||
if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
|
||||
genInferedTypeCollectiveParamBuilder();
|
||||
}
|
||||
|
||||
void OpEmitter::buildParamList(std::string ¶mList,
|
||||
|
Loading…
Reference in New Issue
Block a user