diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index d07a83c58ab..95d92670b81 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -325,6 +325,7 @@ cc_library(
     deps = [
         ":tensorflow_lite",
         ":validators",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/tensorflow",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:Analysis",
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
index 5ff4ffa4b92..4f2efc3710e 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
@@ -66,6 +66,37 @@ static Type GetQuantizedType(Builder builder, Type input_type,
   return converter.convert(quantizedEleType);
 }
 
+// TODO(fengliuai): promote this utility method to mlir QuantOps.
+TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
+  auto factor_values = factor.dyn_cast_or_null<DenseFPElementsAttr>();
+  if (!factor_values) return {};
+  auto ele_type = quant::QuantizedType::getQuantizedElementType(input);
+  if (!ele_type) return {};
+  if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
+    ArrayRef<double> scales = qtype.getScales();
+    // Broadcasting hasn't been implemented yet.
+    if (scales.size() != factor_values.getNumElements()) return {};
+    SmallVector<double, 4> new_scales;
+    new_scales.reserve(scales.size());
+    auto scales_iter = scales.begin();
+    for (auto f : factor_values) {
+      new_scales.push_back(*(scales_iter++) *
+                           std::fabs(FloatAttr::getValueAsDouble(f)));
+    }
+    // We are assuming symmetric quantization.
+    auto new_ele_type = quant::UniformQuantizedPerAxisType::get(
+        qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
+        new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(),
+        qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
+    if (auto new_type = new_ele_type.castFromExpressedType(
+            quant::QuantizedType::castToExpressedType(input))) {
+      return TypeAttr::get(new_type);
+    }
+  }
+  // Currently, we only support per-axis quantized type.
+  return {};
+}
+
 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                               Attribute max, int quant_dim,
                               IntegerAttr num_bits, BoolAttr narrow_range,
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index 60fc2add989..e9a865780dd 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -352,6 +352,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
       return this->matchFailure();
     }
 
+    if (!new_qtype) return this->matchFailure();
     Type new_output_type = new_qtype.castFromExpressedType(
         QType::castToExpressedType(output_type));
     rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
@@ -360,6 +361,11 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
   }
 };
 
+// Given a quantized type `input`, magnifies its scales by the factor stored in
+// `factor`. If `input` isn't a quantized type, or `factor` isn't floating-point
+// or doesn't match the dimension size of `input`, a null TypeAttr is returned.
+TypeAttr RescaleQuantizedType(Type input, Attribute factor);
+
 // Converts the min/max/num_bits/narrow_range information to a
 // QuantizedType, and then returns the attribute containing the QuantizedType.
 // The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index 1c29891b609..f09d338aef6 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -143,6 +143,25 @@ func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1:
   // CHECK-SAME: fused_activation_function = "RELU6"
 }
 
+// CHECK-LABEL: fuseMulIntoConv2dWithQDQs
+func @fuseMulIntoConv2dWithQDQs(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x3xf32> {
+  %cst = constant dense<1.5> : tensor<3xf32>
+  %cst_0 = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %w = constant dense<2.0> : tensor<3x3x3x3xf32>
+  %q = "tfl.quantize"(%w) {qtype = tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32:0,{1.0,2.0,3.0}>>} : (tensor<3x3x3x3xf32>) -> tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32:0,{1.0,2.0,3.0}>>
+  %dq = "tfl.dequantize"(%q) : (tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32:0,{1.0,2.0,3.0}>>) -> tensor<3x3x3x3xf32>
+  %0 = "tfl.conv_2d"(%arg0, %dq, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x3xf32>, tensor<3xf32>) -> tensor<256x30x30x3xf32>
+  %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x3xf32>, tensor<3xf32>) -> tensor<256x30x30x3xf32>
+  return %1 : tensor<256x30x30x3xf32>
+
+  // CHECK: %[[w:.*]] = constant dense<3.000000e+00> : tensor<3x3x3x3xf32>
+  // CHECK: %[[cst:.*]] = constant dense<[1.500000e+00, 3.000000e+00, 4.500000e+00]> : tensor<3xf32>
+  // CHECK: %[[q:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>}
+  // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
+  // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]], %[[cst]])
+  // CHECK: return %[[conv]] : tensor<256x30x30x3xf32>
+}
+
 // CHECK-LABEL: @fuseMulIntoFullyConnected
 func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
   %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 327cd15c2b8..ddc6169f0c9 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -40,6 +40,7 @@ limitations under the License.
 #include "mlir/Support/Functional.h"  // TF:llvm-project
 #include "mlir/Support/LLVM.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -171,6 +172,10 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
   return ExpandTo4DForConvImpl(a, true);
 }
 
+TypeAttr RescaleQtype(Type input, Attribute factor) {
+  return TFL::RescaleQuantizedType(input, factor);
+}
+
 // Returns shape of a ranked tensor.
 // Precondition: output_val's is ranked tensor.
 DenseElementsAttr GetShape(Value output_val) {
@@ -287,7 +292,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
     if (!bias.getType().isa<NoneType>() &&
         !matchPattern(bias, m_Constant(&cst_tmp)))
       return matchFailure();
-    if (fc_op.fused_activation_function().equals("None")) return matchFailure();
+    if (fc_op.fused_activation_function() != "NONE") return matchFailure();
 
     // Broadcast the constant operand of Mul if it isn't compatible to the
     // filter input. We only support broadcasting the operand along the depth
@@ -334,8 +339,110 @@
   }
 };
 
+// Fuse Mul with preceding Affine ops. This is a C++ implementation of the
+// following TableGen pattern, which doesn't derive the result type of
+// the TFL_DequantizeOp.
+// def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
+//                       (TFL_DequantizeOp (TFL_QuantizeOp
+//                           (ConstantOp F32ElementsAttr:$filter), $qtype)),
+//                       (ConstantOp F32ElementsAttr:$bias),
+//                       $h_factor, $w_factor, TFL_AF_None,
+//                       $padding, $stride_h, $stride_w),
+//                      (ConstantOp F32ElementsAttr:$value), $act_fn),
+//           (TFL_Conv2DOp $input,
+//                      (TFL_DequantizeOp (TFL_QuantizeOp
+//                          (TFL_MulOp (ConstantOp $filter),
+//                                     (ConstantOp (ExpandTo4DForConv $value)),
+//                                     TFL_AF_None),
+//                          (RescaleQtype $qtype, $value))),
+//                      (TFL_MulOp (ConstantOp $bias), (ConstantOp $value),
+//                          TFL_AF_None),
+//                      $h_factor, $w_factor, $act_fn,
+//                      $padding, $stride_h, $stride_w),
+//           [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
+//            (HasOneUse $conv_output),
+//            (IsPerAxisQuantization $qtype),  // per-axis quantization
+//           ]>;
+template <typename AffineOpType>
+struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
+  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TFL::MulOp mul_op,
+                                     PatternRewriter &rewriter) const override {
+    // Mul. Requires a 1-D rhs for batch normalization.
+    DenseElementsAttr gamma_cst;
+    Value gamma = mul_op.rhs();
+    if (!matchPattern(gamma, m_Constant(&gamma_cst))) return matchFailure();
+    if (gamma_cst.getType().getRank() != 1) return matchFailure();
+
+    // Affine op
+    Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
+    auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
+    if (!fc_op) return matchFailure();
+    Value filter = fc_op.filter();
+    Value bias = fc_op.bias();
+
+    // QDQs
+    auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
+    if (!dq_op) return matchFailure();
+    auto q_op =
+        dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
+    if (!q_op) return matchFailure();
+    filter = q_op.input();
+
+    // weight constant
+    ElementsAttr cst_tmp;
+    if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
+    if (!bias.getType().isa<NoneType>() &&
+        !matchPattern(bias, m_Constant(&cst_tmp)))
+      return matchFailure();
+    if (fc_op.fused_activation_function() != "NONE") return matchFailure();
+
+    // Broadcast the constant operand of Mul if it isn't compatible to the
+    // filter input. We only support broadcasting the operand along the depth
+    // dimension, when the operand's depth is 1.
+    rewriter.setInsertionPoint(q_op);
+    Location loc = fc_op.getLoc();
+    Value broadcasted_gamma;
+    if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
+      auto mul_rhs = ExpandTo4DForConv(gamma_cst);
+      broadcasted_gamma = rewriter.create<ConstantOp>(loc, mul_rhs);
+    } else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
+      auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
+      broadcasted_gamma = rewriter.create<ConstantOp>(loc, mul_rhs);
+    } else {
+      return matchFailure();
+    }
+
+    // Rewrite filter constant. Since the folder of TFL::MulOp couldn't
+    // broadcast the operands, TF::MulOp is used to fold the constant.
+    auto new_filter =
+        rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
+    // Update the scale in the quantize op.
+    auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
+    if (!new_qtype) return matchFailure();
+    rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
+                                                 new_filter, new_qtype);
+
+    // If bias isn't None, it needs to be multiplied as well.
+    if (!bias.getType().isa<NoneType>()) {
+      rewriter.setInsertionPoint(fc_op);
+      auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
+      fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
+    }
+
+    // Remove the trailing mul op.
+    mul_op.replaceAllUsesWith(fc_op.getResult());
+    return matchSuccess();
+  }
+};
+
+using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
+using FuseDepthwiseConv2DAndMulWithQDQs =
+    FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;
+
 // Fuse Binary Op with following Affine operation.
-template <typename ConcreteType, typename AffineOpType>
+template <typename AffineOpType>
 struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
   using OpRewritePattern<AffineOpType>::OpRewritePattern;
 
@@ -469,37 +576,11 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
   }
 };
 
-class FuseBinaryOpToFollowingFullyConnected
-    : public FuseBinaryOpToFollowingAffineOp<
-          FuseBinaryOpToFollowingFullyConnected, FullyConnectedOp> {
- public:
-  using BaseType =
-      FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingFullyConnected,
-                                      FullyConnectedOp>;
-  explicit FuseBinaryOpToFollowingFullyConnected(MLIRContext *context)
-      : BaseType(context) {}
-};
-
-class FuseBinaryOpToFollowingDepthwiseConv2D
-    : public FuseBinaryOpToFollowingAffineOp<
-          FuseBinaryOpToFollowingDepthwiseConv2D, DepthwiseConv2DOp> {
- public:
-  using BaseType =
-      FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingDepthwiseConv2D,
-                                      DepthwiseConv2DOp>;
-  explicit FuseBinaryOpToFollowingDepthwiseConv2D(MLIRContext *context)
-      : BaseType(context) {}
-};
-
-class FuseBinaryOpToFollowingConv2D
-    : public FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingConv2D,
-                                             Conv2DOp> {
- public:
-  using BaseType =
-      FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingConv2D, Conv2DOp>;
-  explicit FuseBinaryOpToFollowingConv2D(MLIRContext *context)
-      : BaseType(context) {}
-};
+using FuseBinaryOpToFollowingFullyConnected =
+    FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
+using FuseBinaryOpToFollowingDepthwiseConv2D =
+    FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
+using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;
 
 void Optimize::runOnFunction() {
   OwningRewritePatternList patterns;
@@ -517,7 +598,9 @@ void Optimize::runOnFunction() {
 
   // Fuse the binary ops with the following ops.
   patterns.insert<FuseBinaryOpToFollowingConv2D,
                   FuseBinaryOpToFollowingDepthwiseConv2D,
-                  FuseBinaryOpToFollowingFullyConnected>(ctx);
+                  FuseBinaryOpToFollowingFullyConnected,
+                  FuseConv2DAndMulWithQDQs, FuseDepthwiseConv2DAndMulWithQDQs>(
+      ctx);
   applyPatternsGreedily(func, patterns);
 }
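
Note on the rescaling arithmetic (editorial addition, not part of the patch):
RescaleQuantizedType multiplies each per-axis scale of the weight's quantized
type by the absolute value of the matching element of `factor`, which is what
the new optimize.mlir test checks (weight scales {1.0, 2.0, 3.0} scaled by the
mul constant 1.5 become {1.5, 3.0, 4.5}). A minimal standalone C++ sketch of
that arithmetic, using a hypothetical RescaleScales helper and no MLIR APIs:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Hypothetical helper mirroring the loop in RescaleQuantizedType:
    // new_scale[i] = old_scale[i] * |factor[i]|. Returns an empty vector when
    // the sizes don't match, since broadcasting isn't supported there either.
    std::vector<double> RescaleScales(const std::vector<double> &scales,
                                      const std::vector<double> &factor) {
      if (scales.size() != factor.size()) return {};
      std::vector<double> new_scales;
      new_scales.reserve(scales.size());
      for (size_t i = 0; i < scales.size(); ++i)
        new_scales.push_back(scales[i] * std::fabs(factor[i]));
      return new_scales;
    }

    int main() {
      // Mirrors the fuseMulIntoConv2dWithQDQs test: per-channel scales
      // {1.0, 2.0, 3.0} rescaled by the mul constant 1.5 -> {1.5, 3.0, 4.5}.
      for (double s : RescaleScales({1.0, 2.0, 3.0}, {1.5, 1.5, 1.5}))
        std::printf("%g\n", s);
      return 0;
    }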