From 4fb7a4a195d0aba3e4c884f4ecd77d97fdc7a3ce Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 22 Jan 2020 22:01:40 -0800 Subject: [PATCH] Support direct BatchNorm folding when it is quantization aware training Previously the BatchNorm has been folded in the training graph, so it isn't an issue. Folding BatchNorm in training can make it hard to converage, and we are investigating whether the folding can be done by a converter. PiperOrigin-RevId: 291096886 Change-Id: Idf5e27e12982766837546aea89a9fed296983970 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../lite/quantization/quantization_utils.cc | 31 ++++ .../lite/quantization/quantization_utils.h | 6 + .../compiler/mlir/lite/tests/optimize.mlir | 19 +++ .../compiler/mlir/lite/transforms/optimize.cc | 151 ++++++++++++++---- 5 files changed, 174 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index d07a83c58ab..95d92670b81 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -325,6 +325,7 @@ cc_library( deps = [ ":tensorflow_lite", ":validators", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:support", "@llvm-project//mlir:Analysis", diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 5ff4ffa4b92..4f2efc3710e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -66,6 +66,37 @@ static Type GetQuantizedType(Builder builder, Type input_type, return converter.convert(quantizedEleType); } +// TODO(fengliuai): promote this utility method to mlir QuantOps. +TypeAttr RescaleQuantizedType(Type input, Attribute factor) { + auto factor_values = factor.dyn_cast_or_null(); + if (!factor_values) return {}; + auto ele_type = quant::QuantizedType::getQuantizedElementType(input); + if (!ele_type) return {}; + if (auto qtype = ele_type.dyn_cast()) { + ArrayRef scales = qtype.getScales(); + // Broadcasting hasn't been implemented yet. + if (scales.size() != factor_values.getNumElements()) return {}; + SmallVector new_scales; + new_scales.reserve(scales.size()); + auto scales_iter = scales.begin(); + for (auto f : factor_values) { + new_scales.push_back(*(scales_iter++) * + std::fabs(FloatAttr::getValueAsDouble(f))); + } + // We are assuming symmetric quantization. + auto new_ele_type = quant::UniformQuantizedPerAxisType::get( + qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), + new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(), + qtype.getStorageTypeMin(), qtype.getStorageTypeMax()); + if (auto new_type = new_ele_type.castFromExpressedType( + quant::QuantizedType::castToExpressedType(input))) { + return TypeAttr::get(new_type); + } + } + // Currently, we only support per-axis quantized type. + return {}; +} + TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min, Attribute max, int quant_dim, IntegerAttr num_bits, BoolAttr narrow_range, diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 60fc2add989..e9a865780dd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -352,6 +352,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { return this->matchFailure(); } + if (!new_qtype) return this->matchFailure(); Type new_output_type = new_qtype.castFromExpressedType( QType::castToExpressedType(output_type)); rewriter.replaceOpWithNewOp(op, new_output_type, op.input(), @@ -360,6 +361,11 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { } }; +// Given a quantized type `input`, magnifying its scales by the factor stored in +// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the +// dimension size of `input` or isn't floating-point, nullptr will be returned. +TypeAttr RescaleQuantizedType(Type input, Attribute factor); + // Converts the min/max/num_bits/narrow_range information to a // QuantizedType, and then returns the attribute containing the QuantizedType. // The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 1c29891b609..f09d338aef6 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -143,6 +143,25 @@ func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: // CHECK-SAME: fused_activation_function = "RELU6" } +// CHECK-LABEL: fuseMulIntoConv2dWithQDQs +func @fuseMulIntoConv2dWithQDQs(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x3xf32> { + %cst = constant dense<1.5> : tensor<3xf32> + %cst_0 = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32> + %w = constant dense<2.0> : tensor<3x3x3x3xf32> + %q = "tfl.quantize"(%w) {qtype = tensor<3x3x3x3x!quant.uniform:f32:0,{1.0,2.0,3.0}>>} : (tensor<3x3x3x3xf32>) -> tensor<3x3x3x3x!quant.uniform:f32:0,{1.0,2.0,3.0}>> + %dq = "tfl.dequantize"(%q) : (tensor<3x3x3x3x!quant.uniform:f32:0,{1.0,2.0,3.0}>>) -> tensor<3x3x3x3xf32> + %0 = "tfl.conv_2d"(%arg0, %dq, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x3xf32>, tensor<3xf32>) -> tensor<256x30x30x3xf32> + %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x3xf32>, tensor<3xf32>) -> tensor<256x30x30x3xf32> + return %1 : tensor<256x30x30x3xf32> + + // CHECK: %[[w:.*]] = constant dense<3.000000e+00> : tensor<3x3x3x3xf32> + // CHECK: %[[cst:.*]] = constant dense<[1.500000e+00, 3.000000e+00, 4.500000e+00]> : tensor<3xf32> + // CHECK: %[[q:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<3x3x3x3x!quant.uniform:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>} + // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) + // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]], %[[cst]]) + // CHECK: return %[[conv]] : tensor<256x30x30x3xf32> +} + // CHECK-LABEL: @fuseMulIntoFullyConnected func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 327cd15c2b8..ddc6169f0c9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/Support/Functional.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -171,6 +172,10 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) { return ExpandTo4DForConvImpl(a, true); } +TypeAttr RescaleQtype(Type input, Attribute factor) { + return TFL::RescaleQuantizedType(input, factor); +} + // Returns shape of a ranked tensor. // Precondition: output_val's is ranked tensor. DenseElementsAttr GetShape(Value output_val) { @@ -287,7 +292,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { if (!bias.getType().isa() && !matchPattern(bias, m_Constant(&cst_tmp))) return matchFailure(); - if (fc_op.fused_activation_function().equals("None")) return matchFailure(); + if (fc_op.fused_activation_function() != "NONE") return matchFailure(); // Broadcast the constant operand of Mul if it isn't compatible to the // filter input. We only support broadcasting the operand along the depth @@ -334,8 +339,110 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { } }; +// Fuse Mul with proceeding Affine ops. This is an C++ implementation of the +// following table gen implementation, which doesn't derived the result type of +// the TFL_DequantizeOp. +// def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input, +// (TFL_DequantizeOp (TFL_QuantizeOp +// (ConstantOp F32ElementsAttr:$filter), $qtype)), +// (ConstantOp F32ElementsAttr:$bias), +// $h_factor, $w_factor, TFL_AF_None, +// $padding, $stride_h, $stride_w), +// (ConstantOp F32ElementsAttr:$value), $act_fn), +// (TFL_Conv2DOp $input, +// (TFL_DequantizeOp (TFL_QuantizeOp +// (TFL_MulOp (ConstantOp $filter), +// (ConstantOp (ExpandTo4DForConv $value)), +// TFL_AF_None), +// (RescaleQtype $qtype, $value))), +// (TFL_MulOp (ConstantOp $bias), (ConstantOp $value), +// TFL_AF_None), +// $h_factor, $w_factor, $act_fn, +// $padding, $stride_h, $stride_w), +// [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), +// (HasOneUse $conv_output), +// (IsPerAxisQuantization $qtype), // per-axis quantization +// ]>; +template +struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TFL::MulOp mul_op, + PatternRewriter &rewriter) const override { + // Mul. Required 1-D rhs for batch normalization. + DenseElementsAttr gamma_cst; + Value gamma = mul_op.rhs(); + if (!matchPattern(gamma, m_Constant(&gamma_cst))) return matchFailure(); + if (gamma_cst.getType().getRank() != 1) return matchFailure(); + + // Affine op + Operation *mul_op_lhs = mul_op.lhs().getDefiningOp(); + auto fc_op = dyn_cast_or_null(mul_op_lhs); + if (!fc_op) return matchFailure(); + Value filter = fc_op.filter(); + Value bias = fc_op.bias(); + + // QDQs + auto dq_op = dyn_cast_or_null(filter.getDefiningOp()); + if (!dq_op) return matchFailure(); + auto q_op = + dyn_cast_or_null(dq_op.input().getDefiningOp()); + if (!q_op) return matchFailure(); + filter = q_op.input(); + + // weight constant + ElementsAttr cst_tmp; + if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure(); + if (!bias.getType().isa() && + !matchPattern(bias, m_Constant(&cst_tmp))) + return matchFailure(); + if (fc_op.fused_activation_function() != "NONE") return matchFailure(); + + // Broadcast the constant operand of Mul if it isn't compatible to the + // filter input. We only support broadcasting the operand along the depth + // dimension, when the operand's depth is 1. + rewriter.setInsertionPoint(q_op); + Location loc = fc_op.getLoc(); + Value broadcasted_gamma; + if (isa(mul_op_lhs)) { + auto mul_rhs = ExpandTo4DForConv(gamma_cst); + broadcasted_gamma = rewriter.create(loc, mul_rhs); + } else if (isa(mul_op_lhs)) { + auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst); + broadcasted_gamma = rewriter.create(loc, mul_rhs); + } else { + return matchFailure(); + } + + // Rewrite filter constant. Since the folder of TFL::MulOp couldn't + // broadcast the operands, TF::MulOp is used to fold the constant. + auto new_filter = + rewriter.create(loc, filter, broadcasted_gamma).z(); + // Update the scale in the quantize op. + auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst); + if (!new_qtype) return matchFailure(); + rewriter.replaceOpWithNewOp(q_op, new_qtype.getValue(), + new_filter, new_qtype); + + // If bias isn't None, it needs to be multiplied as well. + if (!bias.getType().isa()) { + rewriter.setInsertionPoint(fc_op); + auto new_bias = rewriter.create(loc, bias, gamma); + fc_op.getOperation()->replaceUsesOfWith(bias, new_bias); + } + + // Remove the tailing mul op. + mul_op.replaceAllUsesWith(fc_op.getResult()); + return matchSuccess(); + } +}; + +using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs; +using FuseDepthwiseConv2DAndMulWithQDQs = + FuseAffinOpAndMulWithQDQs; + // Fuse Binary Op with following Affine operation. -template +template struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -469,37 +576,11 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { } }; -class FuseBinaryOpToFollowingFullyConnected - : public FuseBinaryOpToFollowingAffineOp< - FuseBinaryOpToFollowingFullyConnected, FullyConnectedOp> { - public: - using BaseType = - FuseBinaryOpToFollowingAffineOp; - explicit FuseBinaryOpToFollowingFullyConnected(MLIRContext *context) - : BaseType(context) {} -}; - -class FuseBinaryOpToFollowingDepthwiseConv2D - : public FuseBinaryOpToFollowingAffineOp< - FuseBinaryOpToFollowingDepthwiseConv2D, DepthwiseConv2DOp> { - public: - using BaseType = - FuseBinaryOpToFollowingAffineOp; - explicit FuseBinaryOpToFollowingDepthwiseConv2D(MLIRContext *context) - : BaseType(context) {} -}; - -class FuseBinaryOpToFollowingConv2D - : public FuseBinaryOpToFollowingAffineOp { - public: - using BaseType = - FuseBinaryOpToFollowingAffineOp; - explicit FuseBinaryOpToFollowingConv2D(MLIRContext *context) - : BaseType(context) {} -}; +using FuseBinaryOpToFollowingFullyConnected = + FuseBinaryOpToFollowingAffineOp; +using FuseBinaryOpToFollowingDepthwiseConv2D = + FuseBinaryOpToFollowingAffineOp; +using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp; void Optimize::runOnFunction() { OwningRewritePatternList patterns; @@ -517,7 +598,9 @@ void Optimize::runOnFunction() { // Fuse the binary ops with the following ops. patterns.insert(ctx); + FuseBinaryOpToFollowingFullyConnected, + FuseConv2DAndMulWithQDQs, FuseDepthwiseConv2DAndMulWithQDQs>( + ctx); applyPatternsGreedily(func, patterns); }