Updated for changes in LLVM 7a45aeacf3a2.
- The deprecated CreateCall(Value*, ...) IRBuilder API has been removed. - Renamed applyPatternsGreedily to applyPatternsAndFoldGreedily in MLIR. - Update MLIR users after adding support for optional operands/results to ODS (upstream aba1acc89c653b2cc08cccfb754ff16994a05332) - Other updates to BUILD files for upstream changes. PiperOrigin-RevId: 306177884 Change-Id: Idae1009ba89caf296758748ab7aa57815d946a0c
This commit is contained in:
parent
700ff4897f
commit
f5ee735428
@ -496,7 +496,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
auto &value = op.getOperand(i);
|
||||
// Skip from from first variadic operands for now. Else getOperand index
|
||||
// used below doesn't match.
|
||||
if (value.isVariadic()) break;
|
||||
if (value.isVariableLength()) break;
|
||||
if (!value.name.empty())
|
||||
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
|
||||
}
|
||||
@ -504,7 +504,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
auto &value = op.getResult(i);
|
||||
// Skip from from first variadic results for now. Else getResult index
|
||||
// used below doesn't match.
|
||||
if (value.isVariadic()) break;
|
||||
if (value.isVariableLength()) break;
|
||||
if (!value.name.empty())
|
||||
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ void LegalizeTFToQuant::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
auto *ctx = func.getContext();
|
||||
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -30,7 +30,7 @@ void IdentifyDilatedConvPass::runOnFunction() {
|
||||
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
|
||||
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
|
||||
&getContext());
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -711,7 +711,7 @@ void Optimize::runOnFunction() {
|
||||
TFL::populateWithGenerated(ctx, &patterns);
|
||||
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
|
||||
FuseFullyConnectedAndMul>(ctx);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
|
||||
// Fuse the binary ops with the following ops.
|
||||
patterns.insert<
|
||||
@ -719,7 +719,7 @@ void Optimize::runOnFunction() {
|
||||
FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
|
||||
FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp>(
|
||||
ctx);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -187,7 +187,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
|
||||
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
|
||||
|
||||
ModuleOp module = getOperation();
|
||||
applyPatternsGreedily(module, patterns);
|
||||
applyPatternsAndFoldGreedily(module, patterns);
|
||||
|
||||
// Erase inlined functions that don't have any references.
|
||||
//
|
||||
|
@ -125,7 +125,7 @@ void PostQuantizePass::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
auto* ctx = func.getContext();
|
||||
TFL::populateWithGenerated(ctx, &patterns);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
|
||||
if (!emit_quant_adaptor_ops_) {
|
||||
RemoveQuantizationAdaptorOps(getFunction());
|
||||
|
@ -267,7 +267,7 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
// Currently, only activation stats are imported, so narrow_range = false.
|
||||
patterns.insert<PrepareQuantStats>(8, false, false, ctx);
|
||||
}
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
|
||||
SanityCheckAndAdjustment(func);
|
||||
|
||||
|
@ -619,8 +619,8 @@ void PrepareTFPass::runOnFunction() {
|
||||
|
||||
// This pattern was intented to uses TFL QDQs to preserve the quantization
|
||||
// parameters from the TF Quant ops, thus this pattern should run with the
|
||||
// first `applyPatternsGreedily` method, which would otherwise removes the
|
||||
// TF FakeQuant ops by the constant folding.
|
||||
// first `applyPatternsAndFoldGreedily` method, which would otherwise removes
|
||||
// the TF FakeQuant ops by the constant folding.
|
||||
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
|
||||
|
||||
// This pattern will try to identify and optimize for dilated convolution.
|
||||
@ -634,7 +634,7 @@ void PrepareTFPass::runOnFunction() {
|
||||
// This will allow optimizing any TF_Mul->TF_Conv in the graph
|
||||
// and any expanded from FusedBatchNorm. We need to do this
|
||||
// before converting TF_Conv to TFL_Conv
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
|
||||
// Load the generated pattern again, so new quantization pass-through
|
||||
// will be applied.
|
||||
@ -646,7 +646,7 @@ void PrepareTFPass::runOnFunction() {
|
||||
}
|
||||
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
|
||||
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -88,7 +88,7 @@ void QuantizePass::runOnFunction() {
|
||||
TFL::populateWithGenerated(ctx, &patterns);
|
||||
patterns.insert<TFLFullQuantization>(
|
||||
ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -55,7 +55,7 @@ void BatchMatMulToEinsumPass::runOnFunction() {
|
||||
patterns.insert<ConvertTFBatchMatMulToEinsumOp<TF::BatchMatMulOp>,
|
||||
ConvertTFBatchMatMulToEinsumOp<TF::BatchMatMulV2Op>>(
|
||||
&getContext());
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -45,7 +45,7 @@ struct DecomposeResourceOps
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
|
||||
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -364,7 +364,7 @@ void TransformEinsumPass::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
|
||||
patterns.insert<ConvertTFEinsumOp>(&getContext());
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
static PassRegistration<TransformEinsumPass> pass(
|
||||
|
@ -118,7 +118,7 @@ void GpuOpFusionPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ReluToFusedBatchNorm>(&getContext());
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -29,7 +29,7 @@ struct LowerTF : public PassWrapper<LowerTF, FunctionPass> {
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::TF::PopulateLoweringTFPatterns(&getContext(), &patterns);
|
||||
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -38,7 +38,7 @@ struct TFOptimizePass : public PassWrapper<TFOptimizePass, FunctionPass> {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
populateWithGenerated(&getContext(), &patterns);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -55,7 +55,7 @@ void UnrollBatchMatMulPass::runOnFunction() {
|
||||
|
||||
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
|
||||
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -81,7 +81,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
|
||||
// Emit an argument for an operand.
|
||||
if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
|
||||
// Handle a non-variadic operand.
|
||||
if (!operand_cst->isVariadic()) {
|
||||
if (!operand_cst->isVariableLength()) {
|
||||
os << " auto xla_arg_" << index
|
||||
<< " = value_map[*xla_op.getODSOperands(" << operand_number++
|
||||
<< ").begin()];\n";
|
||||
@ -108,7 +108,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
|
||||
|
||||
// If all operands are variadic, then pass the builder explicitly to xla
|
||||
// client API call
|
||||
if (op.getNumOperands() == op.getNumVariadicOperands()) {
|
||||
if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
|
||||
os << "lowering_context.builder";
|
||||
if (op.getNumArgs() != 0) os << ", ";
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
|
||||
void LegalizeToStandard::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeToStandard> legalize_pass(
|
||||
|
@ -87,7 +87,7 @@ struct LhloLegalizeToAffine
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -71,7 +71,7 @@ void LowerComplex::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
|
||||
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
static PassRegistration<LowerComplex> pass(
|
||||
|
@ -178,7 +178,7 @@ struct LegalizeGeneralDot
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
|
||||
&getContext());
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -32,7 +32,7 @@ struct TestUnfuseBatchNormPass
|
||||
void runOnOperation() override {
|
||||
OwningRewritePatternList patterns;
|
||||
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
|
||||
applyPatternsGreedily(getOperation(), patterns);
|
||||
applyPatternsAndFoldGreedily(getOperation(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -79,11 +79,31 @@ class IrBuilderMixin {
|
||||
return mixin_builder()->CreateBr(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
llvm::CallInst* Call(llvm::FunctionCallee func_callee,
|
||||
llvm::ArrayRef<llvm::Value*> args = llvm::None,
|
||||
const llvm::Twine& name = "",
|
||||
llvm::MDNode* fp_math_tag = nullptr) {
|
||||
return mixin_builder()->CreateCall(func_callee, args, name, fp_math_tag);
|
||||
}
|
||||
|
||||
llvm::CallInst* Call(llvm::FunctionType* func_type, llvm::Value* callee,
|
||||
llvm::ArrayRef<llvm::Value*> args = llvm::None,
|
||||
const llvm::Twine& name = "",
|
||||
llvm::MDNode* fp_math_tag = nullptr) {
|
||||
return mixin_builder()->CreateCall(func_type, callee, args, name,
|
||||
fp_math_tag);
|
||||
}
|
||||
|
||||
// DEPRECATED. LLVM is removing getPointerElementType, so calls to this must
|
||||
// be transitioned to one of the other overloads.
|
||||
llvm::CallInst* Call(llvm::Value* callee,
|
||||
llvm::ArrayRef<llvm::Value*> args = llvm::None,
|
||||
const llvm::Twine& name = "",
|
||||
llvm::MDNode* fp_math_tag = nullptr) {
|
||||
return mixin_builder()->CreateCall(callee, args, name, fp_math_tag);
|
||||
return mixin_builder()->CreateCall(
|
||||
llvm::cast<llvm::FunctionType>(
|
||||
callee->getType()->getPointerElementType()),
|
||||
callee, args, name, fp_math_tag);
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
|
14
third_party/mlir/BUILD
vendored
14
third_party/mlir/BUILD
vendored
@ -2666,6 +2666,20 @@ cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "mlir-linalg-ods-gen",
|
||||
srcs = glob([
|
||||
"tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp",
|
||||
]),
|
||||
deps = [
|
||||
":IR",
|
||||
":Support",
|
||||
"@llvm-project//llvm:config",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:tablegen",
|
||||
],
|
||||
)
|
||||
|
||||
## OpenMP dialect
|
||||
gentbl(
|
||||
name = "OpenMPOpsIncGen",
|
||||
|
Loading…
Reference in New Issue
Block a user