Create a generic quantization pattern for quantization graph rewrite

This pattern assumes that the quantization parameter propagation pass has
propagated the quantization parameters to all the quantizable ops. Instead of
defining table gen pattern for each of the quantizable ops, we can use this
single pattern to match all the ops.

The tests are mainly the existing ones, plus some manual tests.

PiperOrigin-RevId: 257213495
This commit is contained in:
Feng Liu 2019-07-09 09:50:10 -07:00 committed by TensorFlower Gardener
parent 6df6fd298c
commit 67fc29f8fb
3 changed files with 57 additions and 135 deletions

View File

@ -50,46 +50,13 @@ struct QuantizePass : public FunctionPass<QuantizePass> {
#include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
struct QuantizeConcatOp : public RewritePattern {
explicit QuantizeConcatOp(MLIRContext* context)
: RewritePattern(QuantizeOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override;
};
PatternMatchResult mlir::TFL::QuantizeConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto quantize_op = cast<QuantizeOp>(op);
auto concat_op =
dyn_cast_or_null<ConcatenationOp>(quantize_op.input()->getDefiningOp());
if (!concat_op) {
return matchFailure();
}
SmallVector<Value*, 4> values;
values.reserve(concat_op.getNumOperands());
for (auto operand : concat_op.values()) {
if (auto opInst =
dyn_cast_or_null<DequantizeOp>(operand->getDefiningOp())) {
values.push_back(opInst.input());
} else {
return matchFailure();
}
}
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
op, quantize_op.output()->getType(), values,
rewriter.getI32IntegerAttr(concat_op.axis().getZExtValue()),
rewriter.getStringAttr(concat_op.fused_activation_function()));
return matchSuccess();
}
void QuantizePass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, &patterns);
mlir::RewriteListBuilder<mlir::TFL::QuantizeConcatOp>::build(patterns, ctx);
mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern<
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx);
applyPatternsGreedily(func, std::move(patterns));
}
} // namespace

View File

@ -22,10 +22,6 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Quantize attribute $0 by using quantization parameter from %1.
def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">;
// Call the generic builder of `op`. Use the result type of $0 in the new op.
class ReplaceWith<string op> : NativeCodeCall<"$_builder.create<" # op #
">($0->getLoc(), $0->getResult(0)->getType(), $1, $2, $3)">;
// Squash tfl.dequantize and tfl.quantize pairs.
// TODO(fengliuai): Compare the scale of input and output. This can also be
// squashed to a requantize op if the scales are different.
@ -39,98 +35,3 @@ def : Pat<(TFL_QuantizeOp
(TFL_QConstOp
$qtype,
(QuantizeByQuantizedType $value, $qtype))>;
// Quantize the AddOp if both inputs are dequantized and the output is
// quantized.
def : Pat<(TFL_QuantizeOp:$q
(TFL_AddOp (TFL_DequantizeOp $lhs), (TFL_DequantizeOp $rhs),
$fused_activation_function),
$output_type),
(ReplaceWith<"TFL::AddOp"> $q, $lhs, $rhs,
$fused_activation_function)>;
// Quantize the Conv2DOp if the input and weight are dequantized. The scale of
// the bias input is determined by the scales of input and weight operands.
def : Pat<(TFL_QuantizeOp
(TFL_Conv2DOp
(TFL_DequantizeOp $in),
(TFL_DequantizeOp $weight),
(TFL_DequantizeOp $bias),
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w),
$output_type),
(TFL_Conv2DOp
$in,
$weight,
$bias,
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w)>;
// Quantize the DepthwiseConv2DOp if the input and weight are dequantized. The
// scale of the bias input is determined by the scales of input and weight
// operands.
def : Pat<(TFL_QuantizeOp
(TFL_DepthwiseConv2DOp
(TFL_DequantizeOp $in),
(TFL_DequantizeOp $weight),
(TFL_DequantizeOp $bias),
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w,
$multiplier),
$output_type),
(TFL_DepthwiseConv2DOp
$in,
$weight,
$bias,
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w,
$multiplier)>;
// Quantize the ReshapeOp if the input is dequantized and output is quantized.
// The pre-quantize pass can guarantee both quantization parameters are the
// same.
def : Pat<(TFL_QuantizeOp (TFL_ReshapeOp (TFL_DequantizeOp $in)), $output_type),
(TFL_ReshapeOp $in)>;
// Quantize the ReshapeOp if the input is dequantized and output is quantized.
// The pre-quantize pass has set the output quantization parameters to a
// pre-defined value.
def : Pat<(TFL_QuantizeOp (TFL_SoftmaxOp (TFL_DequantizeOp $in), $beta),
$output_type),
(TFL_SoftmaxOp $in, $beta)>;
// Quantize the AveragePool2DOp if the input is dequantized and output is
// quantized. The pre-quantize pass can guarantee both quantization parameters
// are the same.
def : Pat<(TFL_QuantizeOp (TFL_AveragePool2DOp (TFL_DequantizeOp $in),
$filter_height, $filter_width, $fused_activation_function,
$padding, $stride_h, $stride_w), $output_type),
(TFL_AveragePool2DOp $in,
$filter_height, $filter_width, $fused_activation_function,
$padding, $stride_h, $stride_w)>;
// Quantize the MaxPool2DOp if the input is dequantized and output is
// quantized. The pre-quantize pass can guarantee both quantization parameters
// are the same.
def : Pat<(TFL_QuantizeOp (TFL_MaxPool2DOp (TFL_DequantizeOp $in),
$padding, $stride_w, $tride_h, $stride_width, $stride_height,
$fused_activation_function), $output_type),
(TFL_MaxPool2DOp $in,
$padding, $stride_w, $tride_h, $stride_width, $stride_height,
$fused_activation_function)>;

View File

@ -20,12 +20,66 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
namespace mlir {
namespace TFL {
// A generic rewrite pattern which matches any N-in-1-out operations with
// quantization parameters propagated to all the operands and results values.
// The quantization parameters are annotated by the Q/DQ op pairs. Each matched
// pattern are rewritten by its quantized alternatives.
//
// This pattern assumes all the matched ops are quantizable. This assumption is
// always right, except when a "Q" op is used as a requantize op. For non-"Q"
// ops, quantization parameters should be propagated to their result.
//
// This pattern only matches ops which only have one result.
template <typename Q, typename DQ>
struct GenericFullQuantizationPattern : public RewritePattern {
explicit GenericFullQuantizationPattern(MLIRContext* context)
: RewritePattern(Q::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
if (op->getNumResults() != 1) {
return matchFailure();
}
auto quantize_op = cast<Q>(op);
auto quantized_op = quantize_op.input()->getDefiningOp();
// If it is a block argument, requantize op, or has more than one result, we
// shouldn't rewrite this op.
if (!quantized_op || llvm::isa<Q>(quantized_op) ||
llvm::isa<DQ>(quantized_op) || quantized_op->getNumResults() != 1) {
return matchFailure();
}
// Collect all the quantized inputs and "clone" the matched op by these
// inputs.
SmallVector<Value*, 4> inputs;
inputs.reserve(quantized_op->getNumOperands());
for (int i = 0, e = quantized_op->getNumOperands(); i != e; ++i) {
auto* operand = quantized_op->getOperand(i);
if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
inputs.push_back(op_inst.input());
} else {
return matchFailure();
}
}
// Use OpBuilder so we can use op name to create the new op.
OpBuilder builder(quantized_op);
OperationState new_state(
quantized_op->getLoc(), quantized_op->getName().getStringRef(), inputs,
op->getResult(0)->getType(), quantized_op->getAttrs());
Operation* new_op = builder.createOperation(new_state);
rewriter.replaceOp(op, {new_op->getResult(0)});
return matchSuccess();
}
};
// Converts the min/max/storage_type/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,