Add type inference variant for separate params builder generated

Add variant that does invoke infer type op interface where defined. Also add entry function that invokes that different separate argument builders for wrapped, unwrapped and inference variant.

PiperOrigin-RevId: 285220709
Change-Id: I5de5702feece344bc7edfeb2147691a02ce1865f
This commit is contained in:
Jacques Pienaar 2019-12-12 10:35:40 -08:00 committed by TensorFlower Gardener
parent 2f28e151b3
commit d45163cfe0
2 changed files with 102 additions and 93 deletions

View File

@ -290,7 +290,7 @@ class. See [Constraints](#constraints) for more information.
### Operation interfaces
[Operation interfaces](Interfaces.md#operation-interfaces) are a mechanism by
which to opaquely call methods and access information on an *Op instance,
which to opaquely call methods and access information on an *Op instance*,
without knowing the exact operation type. Operation interfaces defined in C++
can be accessed in the ODS framework via the `OpInterfaceTrait` class. Aside
from using pre-existing interfaces in the C++ API, the ODS framework also
@ -414,7 +414,7 @@ The following builders are generated:
// All result-types/operands/attributes have one aggregate parameter.
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
ArrayRef<Type> resultTypes,
ArrayRef<Value> operands,
ValueRange operands,
ArrayRef<NamedAttribute> attributes);
// Each result-type/operand/attribute has a separate parameter. The parameters
@ -433,7 +433,19 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state,
Value *i32_operand, Value *f32_operand, ...,
APInt i32_attr, StringRef f32_attr, ...);
// (And potentially others depending on the specific op.)
// Each operand/attribute has a separate parameter but result type is aggregate.
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
ArrayRef<Type> resultTypes,
Value *i32_operand, Value *f32_operand, ...,
IntegerAttr i32_attr, FloatAttr f32_attr, ...);
// All operands/attributes have aggregate parameters.
// Generated if InferTypeOpInterface interface is specified.
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
ValueRange operands,
ArrayRef<NamedAttribute> attributes);
// (And manually specified builders depending on the specific op.)
```
The first form provides basic uniformity so that we can create ops using the

View File

@ -514,21 +514,9 @@ private:
// Generates builder methods for the operation.
void genBuilder();
// Generates the build() method that takes each result-type/operand/attribute
// as a stand-alone parameter. Attributes will take wrapped mlir::Attribute
// values. The generated build() method also requires specifying result types
// for all results.
void genSeparateParamWrappedAttrBuilder();
// Generates the build() method that takes each result-type/operand/attribute
// as a stand-alone parameter. Attributes will take raw values without
// mlir::Attribute wrapper. The generated build() method also requires
// specifying result types for all results.
void genSeparateParamUnwrappedAttrBuilder();
// Generates the build() method that takes a single parameter for all the
// result types and a separate parameter for each operand/attribute.
void genCollectiveTypeParamBuilder();
// Generates the build() method that takes each operand/attribute
// as a stand-alone parameter.
void genSeparateArgParamBuilder();
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first operand's
@ -897,26 +885,11 @@ void OpEmitter::genNamedRegionGetters() {
}
}
void OpEmitter::genSeparateParamWrappedAttrBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::Separate);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
// Push all result types to the operation state
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
m.body() << " " << builderOpState << ".addTypes(" << resultNames[i]
<< ");\n";
}
}
void OpEmitter::genSeparateParamUnwrappedAttrBuilder() {
static bool canGenerateUnwrappedBuilder(Operator &op) {
// If this op does not have native attributes at all, return directly to avoid
// redefining builders.
if (op.getNumNativeAttributes() == 0)
return;
return false;
bool canGenerate = false;
// We are generating builders that take raw values for attributes. We need to
@ -930,47 +903,75 @@ void OpEmitter::genSeparateParamUnwrappedAttrBuilder() {
break;
}
}
if (!canGenerate)
return;
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::Separate,
AttrParamKind::UnwrappedValue);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body(), /*isRawValueAttr=*/true);
// Push all result types to the operation state.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
m.body() << " " << builderOpState << ".addTypes(" << resultNames[i]
<< ");\n";
}
return canGenerate;
}
void OpEmitter::genCollectiveTypeParamBuilder() {
auto numResults = op.getNumResults();
void OpEmitter::genSeparateArgParamBuilder() {
SmallVector<AttrParamKind, 2> attrBuilderType;
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
if (canGenerateUnwrappedBuilder(op))
attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
// If this op has no results, then just skip generating this builder.
// Otherwise we are generating the same signature as the separate-parameter
// builder.
if (numResults == 0)
return;
// Emit with separate builders with or without unwrapped attributes and/or
// inferring result type.
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
bool inferType) {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, paramKind, attrType);
// Similarly for ops with one single variadic result, which will also have one
// `ArrayRef<Type>` parameter for the result type.
if (numResults == 1 && op.getResult(0).isVariadic())
return;
auto &m =
opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
auto &body = m.body();
genCodeForAddingArgAndRegionForBuilder(
body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::Collective);
// Push all result types to the operation state
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
if (inferType) {
// Generate builder that infers type too.
// TODO(jpienaar): Subsume this with general checking if type can be
// infered automatically.
// TODO(jpienaar): Expand to handle regions.
body << formatv(R"(
SmallVector<Type, 2> inferedReturnTypes;
if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands,
{1}.attributes, /*regions=*/{{}, inferedReturnTypes)))
{1}.addTypes(inferedReturnTypes);
else
llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
return;
}
// Push all result types to the operation state
m.body() << formatv(" {0}.addTypes(resultTypes);\n", builderOpState);
switch (paramKind) {
case TypeParamKind::None:
return;
case TypeParamKind::Separate:
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
body << " " << builderOpState << ".addTypes(" << resultNames[i]
<< ");\n";
}
return;
case TypeParamKind::Collective:
body << " " << builderOpState << ".addTypes(resultTypes);\n";
return;
};
llvm_unreachable("unhandled TypeParamKind");
};
bool canInferType =
op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0;
for (auto attrType : attrBuilderType) {
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
if (canInferType)
emit(attrType, TypeParamKind::None, /*inferType=*/true);
// Emit separate arg build with collective type, unless there is only one
// variadic result, in which case the above would have already generated
// the same build method.
if (op.getNumResults() == 1 && !op.getResult(0).isVariadic())
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
}
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
@ -1021,8 +1022,7 @@ void OpEmitter::genInferedTypeCollectiveParamBuilder() {
/*regions=*/{{}, inferedReturnTypes)))
build(builder, tblgen_state, inferedReturnTypes, operands, attributes);
else
llvm::report_fatal_error("Failed to infer result type(s).");
)",
llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
}
@ -1111,18 +1111,13 @@ void OpEmitter::genBuilder() {
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
// We generate three builders here:
// 1. one having a stand-alone parameter for each result type / operand /
// attribute, and
genSeparateParamWrappedAttrBuilder();
genSeparateParamUnwrappedAttrBuilder();
// 2. one having a stand-alone parameter for each operand / attribute and
// an aggregated parameter for all result types, and
genCollectiveTypeParamBuilder();
// 3. one having an aggregated parameter for all result types / operands /
// We generate three classes of builders here:
// 1. one having a stand-alone parameter for each operand / attribute, and
genSeparateArgParamBuilder();
// 2. one having an aggregated parameter for all result types / operands /
// attributes, and
genCollectiveParamBuilder();
// 4. one having a stand-alone parameter for each operand and attribute,
// 3. one having a stand-alone parameter for each operand and attribute,
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
if (op.getNumVariadicResults() == 0) {
@ -1133,11 +1128,6 @@ void OpEmitter::genBuilder() {
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder();
}
// TODO(jpienaar): Subsume this with general checking if type can be infered
// automatically.
// TODO(jpienaar): Expand to handle regions.
if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
genInferedTypeCollectiveParamBuilder();
}
void OpEmitter::genCollectiveParamBuilder() {
@ -1156,13 +1146,6 @@ void OpEmitter::genCollectiveParamBuilder() {
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
auto &body = m.body();
// Result types
if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(resultTypes.size()"
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
// Operands
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()"
@ -1179,6 +1162,20 @@ void OpEmitter::genCollectiveParamBuilder() {
for (int i = 0; i < numRegions; ++i)
m.body() << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(resultTypes.size()"
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
// Generate builder that infers type too.
// TODO(jpienaar): Subsume this with general checking if type can be infered
// automatically.
// TODO(jpienaar): Expand to handle regions.
if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
genInferedTypeCollectiveParamBuilder();
}
void OpEmitter::buildParamList(std::string &paramList,