diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc
index 7d98487ded2..6e7d060cca5 100644
--- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc
@@ -50,52 +50,19 @@ struct QuantizePass : public FunctionPass<QuantizePass> {
 
 #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
 
-struct QuantizeConcatOp : public RewritePattern {
-  explicit QuantizeConcatOp(MLIRContext* context)
-      : RewritePattern(QuantizeOp::getOperationName(), 1, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation* op,
-                                     PatternRewriter& rewriter) const override;
-};
-
-PatternMatchResult mlir::TFL::QuantizeConcatOp::matchAndRewrite(
-    Operation* op, PatternRewriter& rewriter) const {
-  auto quantize_op = cast<QuantizeOp>(op);
-  auto concat_op =
-      dyn_cast_or_null<ConcatenationOp>(quantize_op.input()->getDefiningOp());
-  if (!concat_op) {
-    return matchFailure();
-  }
-
-  SmallVector<Value*, 4> values;
-  values.reserve(concat_op.getNumOperands());
-  for (auto operand : concat_op.values()) {
-    if (auto opInst =
-            dyn_cast_or_null<DequantizeOp>(operand->getDefiningOp())) {
-      values.push_back(opInst.input());
-    } else {
-      return matchFailure();
-    }
-  }
-  rewriter.replaceOpWithNewOp<ConcatenationOp>(
-      op, quantize_op.output()->getType(), values,
-      rewriter.getI32IntegerAttr(concat_op.axis().getZExtValue()),
-      rewriter.getStringAttr(concat_op.fused_activation_function()));
-  return matchSuccess();
-}
-
 void QuantizePass::runOnFunction() {
   OwningRewritePatternList patterns;
   auto func = getFunction();
   auto* ctx = func.getContext();
   TFL::populateWithGenerated(ctx, &patterns);
-  mlir::RewriteListBuilder<QuantizeConcatOp>::build(patterns, ctx);
+  mlir::RewriteListBuilder<TFL::GenericFullQuantizationPattern<
+      TFL::QuantizeOp, TFL::DequantizeOp>>::build(patterns, ctx);
   applyPatternsGreedily(func, std::move(patterns));
 }
 }  // namespace
 
 // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
-FunctionPassBase *CreateQuantizePass() { return new QuantizePass(); }
+FunctionPassBase* CreateQuantizePass() { return new QuantizePass(); }
 
 static PassRegistration<QuantizePass> pass(
     "tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect");
diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
index 756fae3a4cd..7fcf926d89f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
@@ -22,10 +22,6 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 // Quantize attribute $0 by using quantization parameter from %1.
 def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">;
 
-// Call the generic builder of `op`. Use the result type of $0 in the new op.
-class ReplaceWith<string op> : NativeCodeCall<"$_builder.create<" # op #
-    ">($0->getLoc(), $0->getResult(0)->getType(), $1, $2, $3)">;
-
 // Squash tfl.dequantize and tfl.quantize pairs.
 // TODO(fengliuai): Compare the scale of input and output. This can also be
 // squashed to a requantize op if the scales are different.
@@ -39,98 +35,3 @@ def : Pat<(TFL_QuantizeOp
           (TFL_QConstOp $qtype, (QuantizeByQuantizedType $value, $qtype))>;
-
-// Quantize the AddOp if both inputs are dequantized and the output is
-// quantized.
-def : Pat<(TFL_QuantizeOp:$q
-              (TFL_AddOp (TFL_DequantizeOp $lhs), (TFL_DequantizeOp $rhs),
-                  $fused_activation_function),
-              $output_type),
-          (ReplaceWith<"TFL::AddOp"> $q, $lhs, $rhs,
-              $fused_activation_function)>;
-
-// Quantize the Conv2DOp if the input and weight are dequantized. The scale of
-// the bias input is determined by the scales of input and weight operands.
-def : Pat<(TFL_QuantizeOp
-              (TFL_Conv2DOp
-                  (TFL_DequantizeOp $in),
-                  (TFL_DequantizeOp $weight),
-                  (TFL_DequantizeOp $bias),
-                  $dilation_h_factor,
-                  $dilation_w_factor,
-                  $fused_activation_function,
-                  $padding,
-                  $stride_h,
-                  $stride_w),
-              $output_type),
-          (TFL_Conv2DOp
-              $in,
-              $weight,
-              $bias,
-              $dilation_h_factor,
-              $dilation_w_factor,
-              $fused_activation_function,
-              $padding,
-              $stride_h,
-              $stride_w)>;
-
-// Quantize the DepthwiseConv2DOp if the input and weight are dequantized. The
-// scale of the bias input is determined by the scales of input and weight
-// operands.
-def : Pat<(TFL_QuantizeOp
-              (TFL_DepthwiseConv2DOp
-                  (TFL_DequantizeOp $in),
-                  (TFL_DequantizeOp $weight),
-                  (TFL_DequantizeOp $bias),
-                  $dilation_h_factor,
-                  $dilation_w_factor,
-                  $fused_activation_function,
-                  $padding,
-                  $stride_h,
-                  $stride_w,
-                  $multiplier),
-              $output_type),
-          (TFL_DepthwiseConv2DOp
-              $in,
-              $weight,
-              $bias,
-              $dilation_h_factor,
-              $dilation_w_factor,
-              $fused_activation_function,
-              $padding,
-              $stride_h,
-              $stride_w,
-              $multiplier)>;
-
-// Quantize the ReshapeOp if the input is dequantized and output is quantized.
-// The pre-quantize pass can guarantee both quantization parameters are the
-// same.
-def : Pat<(TFL_QuantizeOp (TFL_ReshapeOp (TFL_DequantizeOp $in)), $output_type),
-          (TFL_ReshapeOp $in)>;
-
-// Quantize the ReshapeOp if the input is dequantized and output is quantized.
-// The pre-quantize pass has set the output quantization parameters to a
-// pre-defined value.
-def : Pat<(TFL_QuantizeOp (TFL_SoftmaxOp (TFL_DequantizeOp $in), $beta),
-              $output_type),
-          (TFL_SoftmaxOp $in, $beta)>;
-
-// Quantize the AveragePool2DOp if the input is dequantized and output is
-// quantized. The pre-quantize pass can guarantee both quantization parameters
-// are the same.
-def : Pat<(TFL_QuantizeOp (TFL_AveragePool2DOp (TFL_DequantizeOp $in),
-              $filter_height, $filter_width, $fused_activation_function,
-              $padding, $stride_h, $stride_w), $output_type),
-          (TFL_AveragePool2DOp $in,
-              $filter_height, $filter_width, $fused_activation_function,
-              $padding, $stride_h, $stride_w)>;
-
-// Quantize the MaxPool2DOp if the input is dequantized and output is
-// quantized. The pre-quantize pass can guarantee both quantization parameters
-// are the same.
-def : Pat<(TFL_QuantizeOp (TFL_MaxPool2DOp (TFL_DequantizeOp $in),
-              $padding, $stride_w, $tride_h, $stride_width, $stride_height,
-              $fused_activation_function), $output_type),
-          (TFL_MaxPool2DOp $in,
-              $padding, $stride_w, $tride_h, $stride_width, $stride_height,
-              $fused_activation_function)>;
diff --git a/tensorflow/compiler/mlir/lite/utils/quantization_utils.h b/tensorflow/compiler/mlir/lite/utils/quantization_utils.h
index cee00e6be38..33fe39cb05c 100644
--- a/tensorflow/compiler/mlir/lite/utils/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/quantization_utils.h
@@ -20,12 +20,66 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
 
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
 
 namespace mlir {
 namespace TFL {
 
+// A generic rewrite pattern which matches any N-in-1-out operation whose
+// quantization parameters have been propagated to all of its operand and
+// result values. The quantization parameters are annotated by the Q/DQ op
+// pairs. Each matched pattern is rewritten by its quantized alternative.
+//
+// This pattern assumes all the matched ops are quantizable. This assumption is
+// always right, except when a "Q" op is used as a requantize op. For non-"Q"
+// ops, quantization parameters should be propagated to their results.
+//
+// This pattern only matches ops which have a single result.
+template <typename Q, typename DQ>
+struct GenericFullQuantizationPattern : public RewritePattern {
+  explicit GenericFullQuantizationPattern(MLIRContext* context)
+      : RewritePattern(Q::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation* op,
+                                     PatternRewriter& rewriter) const override {
+    if (op->getNumResults() != 1) {
+      return matchFailure();
+    }
+    auto quantize_op = cast<Q>(op);
+    auto quantized_op = quantize_op.input()->getDefiningOp();
+    // If it is a block argument, a requantize op, or has more than one result,
+    // we shouldn't rewrite this op.
+    if (!quantized_op || llvm::isa<Q>(quantized_op) ||
+        llvm::isa<DQ>(quantized_op) || quantized_op->getNumResults() != 1) {
+      return matchFailure();
+    }
+
+    // Collect all the quantized inputs and "clone" the matched op by these
+    // inputs.
+    SmallVector<Value*, 4> inputs;
+    inputs.reserve(quantized_op->getNumOperands());
+    for (int i = 0, e = quantized_op->getNumOperands(); i != e; ++i) {
+      auto* operand = quantized_op->getOperand(i);
+      if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
+        inputs.push_back(op_inst.input());
+      } else {
+        return matchFailure();
+      }
+    }
+    // Use OpBuilder so we can use the op name to create the new op.
+    OpBuilder builder(quantized_op);
+    OperationState new_state(
+        quantized_op->getLoc(), quantized_op->getName().getStringRef(), inputs,
+        op->getResult(0)->getType(), quantized_op->getAttrs());
+    Operation* new_op = builder.createOperation(new_state);
+    rewriter.replaceOp(op, {new_op->getResult(0)});
+    return matchSuccess();
+  }
+};
+
 // Converts the min/max/storage_type/narrow_range information to a
 // QuantizedType, and then returns the attribute containing the QuantizedType.
 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
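
The following is a minimal usage sketch of the GenericFullQuantizationPattern introduced above; it is not part of the diff itself. It mirrors the registration already done in quantize.cc, assuming the same MLIR APIs that file uses (OwningRewritePatternList, RewriteListBuilder, applyPatternsGreedily). The helper name PopulateGenericFullQuantizationPatterns is hypothetical and only illustrates how a pass could reuse the template with its own quantize/dequantize op pair.

#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h"

namespace mlir {
namespace TFL {

// Builds the generic Q/DQ rewrite for the TFL dialect. With
// Q = TFL::QuantizeOp and DQ = TFL::DequantizeOp, the pattern matches IR of
// the shape
//   %0 = "tfl.dequantize"(%a)
//   %1 = "tfl.dequantize"(%b)
//   %2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"}
//   %3 = "tfl.quantize"(%2) {qtype = ...}
// and replaces %3 with a new tfl.add that consumes %a and %b directly and
// carries the quantized result type of %3; the float op and the Q/DQ pairs
// become dead and can then be removed.
void PopulateGenericFullQuantizationPatterns(MLIRContext* ctx,
                                             OwningRewritePatternList* patterns) {
  RewriteListBuilder<GenericFullQuantizationPattern<QuantizeOp, DequantizeOp>>::
      build(*patterns, ctx);
}

}  // namespace TFL
}  // namespace mlir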