Updated for changes in LLVM 7a45aeacf3a2.
- The deprecated CreateCall(Value*, ...) IRBuilder API has been removed. - Renamed applyPatternsGreedily to applyPatternsAndFoldGreedily in MLIR. - Update MLIR users after adding support for optional operands/results to ODS (upstream aba1acc89c653b2cc08cccfb754ff16994a05332) - Other updates to BUILD files for upstream changes. PiperOrigin-RevId: 306177884 Change-Id: Idae1009ba89caf296758748ab7aa57815d946a0c
This commit is contained in:
parent
700ff4897f
commit
f5ee735428
@ -496,7 +496,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
|||||||
auto &value = op.getOperand(i);
|
auto &value = op.getOperand(i);
|
||||||
// Skip from from first variadic operands for now. Else getOperand index
|
// Skip from from first variadic operands for now. Else getOperand index
|
||||||
// used below doesn't match.
|
// used below doesn't match.
|
||||||
if (value.isVariadic()) break;
|
if (value.isVariableLength()) break;
|
||||||
if (!value.name.empty())
|
if (!value.name.empty())
|
||||||
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
|
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
|
||||||
}
|
}
|
||||||
@ -504,7 +504,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
|||||||
auto &value = op.getResult(i);
|
auto &value = op.getResult(i);
|
||||||
// Skip from from first variadic results for now. Else getResult index
|
// Skip from from first variadic results for now. Else getResult index
|
||||||
// used below doesn't match.
|
// used below doesn't match.
|
||||||
if (value.isVariadic()) break;
|
if (value.isVariableLength()) break;
|
||||||
if (!value.name.empty())
|
if (!value.name.empty())
|
||||||
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
|
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
|
||||||
}
|
}
|
||||||
|
@ -146,7 +146,7 @@ void LegalizeTFToQuant::runOnFunction() {
|
|||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
auto *ctx = func.getContext();
|
auto *ctx = func.getContext();
|
||||||
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
|
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ void IdentifyDilatedConvPass::runOnFunction() {
|
|||||||
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
|
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
|
||||||
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
|
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
|
||||||
&getContext());
|
&getContext());
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -711,7 +711,7 @@ void Optimize::runOnFunction() {
|
|||||||
TFL::populateWithGenerated(ctx, &patterns);
|
TFL::populateWithGenerated(ctx, &patterns);
|
||||||
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
|
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
|
||||||
FuseFullyConnectedAndMul>(ctx);
|
FuseFullyConnectedAndMul>(ctx);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
|
|
||||||
// Fuse the binary ops with the following ops.
|
// Fuse the binary ops with the following ops.
|
||||||
patterns.insert<
|
patterns.insert<
|
||||||
@ -719,7 +719,7 @@ void Optimize::runOnFunction() {
|
|||||||
FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
|
FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
|
||||||
FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp>(
|
FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp>(
|
||||||
ctx);
|
ctx);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -187,7 +187,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
|
|||||||
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
|
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
|
||||||
|
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
applyPatternsGreedily(module, patterns);
|
applyPatternsAndFoldGreedily(module, patterns);
|
||||||
|
|
||||||
// Erase inlined functions that don't have any references.
|
// Erase inlined functions that don't have any references.
|
||||||
//
|
//
|
||||||
|
@ -125,7 +125,7 @@ void PostQuantizePass::runOnFunction() {
|
|||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
auto* ctx = func.getContext();
|
auto* ctx = func.getContext();
|
||||||
TFL::populateWithGenerated(ctx, &patterns);
|
TFL::populateWithGenerated(ctx, &patterns);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
|
|
||||||
if (!emit_quant_adaptor_ops_) {
|
if (!emit_quant_adaptor_ops_) {
|
||||||
RemoveQuantizationAdaptorOps(getFunction());
|
RemoveQuantizationAdaptorOps(getFunction());
|
||||||
|
@ -267,7 +267,7 @@ void PrepareQuantizePass::runOnFunction() {
|
|||||||
// Currently, only activation stats are imported, so narrow_range = false.
|
// Currently, only activation stats are imported, so narrow_range = false.
|
||||||
patterns.insert<PrepareQuantStats>(8, false, false, ctx);
|
patterns.insert<PrepareQuantStats>(8, false, false, ctx);
|
||||||
}
|
}
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
|
|
||||||
SanityCheckAndAdjustment(func);
|
SanityCheckAndAdjustment(func);
|
||||||
|
|
||||||
|
@ -619,8 +619,8 @@ void PrepareTFPass::runOnFunction() {
|
|||||||
|
|
||||||
// This pattern was intented to uses TFL QDQs to preserve the quantization
|
// This pattern was intented to uses TFL QDQs to preserve the quantization
|
||||||
// parameters from the TF Quant ops, thus this pattern should run with the
|
// parameters from the TF Quant ops, thus this pattern should run with the
|
||||||
// first `applyPatternsGreedily` method, which would otherwise removes the
|
// first `applyPatternsAndFoldGreedily` method, which would otherwise removes
|
||||||
// TF FakeQuant ops by the constant folding.
|
// the TF FakeQuant ops by the constant folding.
|
||||||
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
|
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
|
||||||
|
|
||||||
// This pattern will try to identify and optimize for dilated convolution.
|
// This pattern will try to identify and optimize for dilated convolution.
|
||||||
@ -634,7 +634,7 @@ void PrepareTFPass::runOnFunction() {
|
|||||||
// This will allow optimizing any TF_Mul->TF_Conv in the graph
|
// This will allow optimizing any TF_Mul->TF_Conv in the graph
|
||||||
// and any expanded from FusedBatchNorm. We need to do this
|
// and any expanded from FusedBatchNorm. We need to do this
|
||||||
// before converting TF_Conv to TFL_Conv
|
// before converting TF_Conv to TFL_Conv
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
|
|
||||||
// Load the generated pattern again, so new quantization pass-through
|
// Load the generated pattern again, so new quantization pass-through
|
||||||
// will be applied.
|
// will be applied.
|
||||||
@ -646,7 +646,7 @@ void PrepareTFPass::runOnFunction() {
|
|||||||
}
|
}
|
||||||
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
|
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
|
||||||
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
|
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -88,7 +88,7 @@ void QuantizePass::runOnFunction() {
|
|||||||
TFL::populateWithGenerated(ctx, &patterns);
|
TFL::populateWithGenerated(ctx, &patterns);
|
||||||
patterns.insert<TFLFullQuantization>(
|
patterns.insert<TFLFullQuantization>(
|
||||||
ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
|
ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ void BatchMatMulToEinsumPass::runOnFunction() {
|
|||||||
patterns.insert<ConvertTFBatchMatMulToEinsumOp<TF::BatchMatMulOp>,
|
patterns.insert<ConvertTFBatchMatMulToEinsumOp<TF::BatchMatMulOp>,
|
||||||
ConvertTFBatchMatMulToEinsumOp<TF::BatchMatMulV2Op>>(
|
ConvertTFBatchMatMulToEinsumOp<TF::BatchMatMulV2Op>>(
|
||||||
&getContext());
|
&getContext());
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -45,7 +45,7 @@ struct DecomposeResourceOps
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
mlir::TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
|
mlir::TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
|
||||||
|
|
||||||
applyPatternsGreedily(getFunction(), patterns);
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -364,7 +364,7 @@ void TransformEinsumPass::runOnFunction() {
|
|||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
|
|
||||||
patterns.insert<ConvertTFEinsumOp>(&getContext());
|
patterns.insert<ConvertTFEinsumOp>(&getContext());
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<TransformEinsumPass> pass(
|
static PassRegistration<TransformEinsumPass> pass(
|
||||||
|
@ -118,7 +118,7 @@ void GpuOpFusionPass::runOnFunction() {
|
|||||||
FuncOp func = getFunction();
|
FuncOp func = getFunction();
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<ReluToFusedBatchNorm>(&getContext());
|
patterns.insert<ReluToFusedBatchNorm>(&getContext());
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -29,7 +29,7 @@ struct LowerTF : public PassWrapper<LowerTF, FunctionPass> {
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
mlir::TF::PopulateLoweringTFPatterns(&getContext(), &patterns);
|
mlir::TF::PopulateLoweringTFPatterns(&getContext(), &patterns);
|
||||||
|
|
||||||
applyPatternsGreedily(getFunction(), patterns);
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ struct TFOptimizePass : public PassWrapper<TFOptimizePass, FunctionPass> {
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
populateWithGenerated(&getContext(), &patterns);
|
populateWithGenerated(&getContext(), &patterns);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ void UnrollBatchMatMulPass::runOnFunction() {
|
|||||||
|
|
||||||
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
|
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
|
||||||
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
|
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -81,7 +81,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
|
|||||||
// Emit an argument for an operand.
|
// Emit an argument for an operand.
|
||||||
if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
|
if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
|
||||||
// Handle a non-variadic operand.
|
// Handle a non-variadic operand.
|
||||||
if (!operand_cst->isVariadic()) {
|
if (!operand_cst->isVariableLength()) {
|
||||||
os << " auto xla_arg_" << index
|
os << " auto xla_arg_" << index
|
||||||
<< " = value_map[*xla_op.getODSOperands(" << operand_number++
|
<< " = value_map[*xla_op.getODSOperands(" << operand_number++
|
||||||
<< ").begin()];\n";
|
<< ").begin()];\n";
|
||||||
@ -108,7 +108,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
|
|||||||
|
|
||||||
// If all operands are variadic, then pass the builder explicitly to xla
|
// If all operands are variadic, then pass the builder explicitly to xla
|
||||||
// client API call
|
// client API call
|
||||||
if (op.getNumOperands() == op.getNumVariadicOperands()) {
|
if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
|
||||||
os << "lowering_context.builder";
|
os << "lowering_context.builder";
|
||||||
if (op.getNumArgs() != 0) os << ", ";
|
if (op.getNumArgs() != 0) os << ", ";
|
||||||
}
|
}
|
||||||
|
@ -198,7 +198,7 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
|
|||||||
void LegalizeToStandard::runOnFunction() {
|
void LegalizeToStandard::runOnFunction() {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
|
mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
|
||||||
applyPatternsGreedily(getFunction(), patterns);
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<LegalizeToStandard> legalize_pass(
|
static PassRegistration<LegalizeToStandard> legalize_pass(
|
||||||
|
@ -87,7 +87,7 @@ struct LhloLegalizeToAffine
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
|
populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ void LowerComplex::runOnFunction() {
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
|
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
|
||||||
|
|
||||||
applyPatternsGreedily(getFunction(), patterns);
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<LowerComplex> pass(
|
static PassRegistration<LowerComplex> pass(
|
||||||
|
@ -178,7 +178,7 @@ struct LegalizeGeneralDot
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
|
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
|
||||||
&getContext());
|
&getContext());
|
||||||
applyPatternsGreedily(getFunction(), patterns);
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ struct TestUnfuseBatchNormPass
|
|||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
|
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
|
||||||
applyPatternsGreedily(getOperation(), patterns);
|
applyPatternsAndFoldGreedily(getOperation(), patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -79,11 +79,31 @@ class IrBuilderMixin {
|
|||||||
return mixin_builder()->CreateBr(std::forward<Args>(args)...);
|
return mixin_builder()->CreateBr(std::forward<Args>(args)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llvm::CallInst* Call(llvm::FunctionCallee func_callee,
|
||||||
|
llvm::ArrayRef<llvm::Value*> args = llvm::None,
|
||||||
|
const llvm::Twine& name = "",
|
||||||
|
llvm::MDNode* fp_math_tag = nullptr) {
|
||||||
|
return mixin_builder()->CreateCall(func_callee, args, name, fp_math_tag);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::CallInst* Call(llvm::FunctionType* func_type, llvm::Value* callee,
|
||||||
|
llvm::ArrayRef<llvm::Value*> args = llvm::None,
|
||||||
|
const llvm::Twine& name = "",
|
||||||
|
llvm::MDNode* fp_math_tag = nullptr) {
|
||||||
|
return mixin_builder()->CreateCall(func_type, callee, args, name,
|
||||||
|
fp_math_tag);
|
||||||
|
}
|
||||||
|
|
||||||
|
// DEPRECATED. LLVM is removing getPointerElementType, so calls to this must
|
||||||
|
// be transitioned to one of the other overloads.
|
||||||
llvm::CallInst* Call(llvm::Value* callee,
|
llvm::CallInst* Call(llvm::Value* callee,
|
||||||
llvm::ArrayRef<llvm::Value*> args = llvm::None,
|
llvm::ArrayRef<llvm::Value*> args = llvm::None,
|
||||||
const llvm::Twine& name = "",
|
const llvm::Twine& name = "",
|
||||||
llvm::MDNode* fp_math_tag = nullptr) {
|
llvm::MDNode* fp_math_tag = nullptr) {
|
||||||
return mixin_builder()->CreateCall(callee, args, name, fp_math_tag);
|
return mixin_builder()->CreateCall(
|
||||||
|
llvm::cast<llvm::FunctionType>(
|
||||||
|
callee->getType()->getPointerElementType()),
|
||||||
|
callee, args, name, fp_math_tag);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class... Args>
|
template <class... Args>
|
||||||
|
14
third_party/mlir/BUILD
vendored
14
third_party/mlir/BUILD
vendored
@ -2666,6 +2666,20 @@ cc_binary(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "mlir-linalg-ods-gen",
|
||||||
|
srcs = glob([
|
||||||
|
"tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp",
|
||||||
|
]),
|
||||||
|
deps = [
|
||||||
|
":IR",
|
||||||
|
":Support",
|
||||||
|
"@llvm-project//llvm:config",
|
||||||
|
"@llvm-project//llvm:support",
|
||||||
|
"@llvm-project//llvm:tablegen",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
## OpenMP dialect
|
## OpenMP dialect
|
||||||
gentbl(
|
gentbl(
|
||||||
name = "OpenMPOpsIncGen",
|
name = "OpenMPOpsIncGen",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user