Support direct BatchNorm folding for quantization-aware training
Previously, BatchNorm was folded in the training graph, so this wasn't an issue. Folding BatchNorm during training can make the model hard to converge, so we are investigating whether the folding can instead be done by the converter.

PiperOrigin-RevId: 291096886
Change-Id: Idf5e27e12982766837546aea89a9fed296983970
parent e4fe5890d7
commit 4fb7a4a195
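At a high level, this change adds an MLIR rewrite that folds a trailing tfl.mul (the per-channel scale that BatchNorm contributes at inference time) into the quantized weights of the preceding convolution, instead of requiring the fold to happen in the training graph. A minimal C++ sketch of the arithmetic, assuming a uniform per-channel multiplier gamma and symmetric per-axis quantization; the values mirror the fuseMulIntoConv2dWithQDQs test added below, and the snippet is illustrative only, not code from the patch:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Multiplier contributed by the trailing tfl.mul; for BatchNorm this is
  // gamma / sqrt(variance + epsilon) per output channel.
  const double gamma = 1.5;
  // Float filter value, bias, and per-axis quantization scales before folding.
  double filter = 2.0;
  std::vector<double> bias = {1.0, 2.0, 3.0};
  std::vector<double> scales = {1.0, 2.0, 3.0};

  filter *= gamma;                    // 2.0 -> 3.0, the new filter constant
  for (double &b : bias) b *= gamma;  // {1, 2, 3} -> {1.5, 3.0, 4.5}
  // Because quantization is symmetric, the stored integer weights can stay
  // unchanged if each per-axis scale is magnified by |gamma| instead.
  for (double &s : scales) s *= std::fabs(gamma);  // {1, 2, 3} -> {1.5, 3.0, 4.5}

  std::printf("filter=%.1f bias[0]=%.1f scale[0]=%.1f\n", filter, bias[0], scales[0]);
  return 0;
}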
@@ -325,6 +325,7 @@ cc_library(
    deps = [
        ":tensorflow_lite",
        ":validators",
+       "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:Analysis",
@@ -66,6 +66,37 @@ static Type GetQuantizedType(Builder builder, Type input_type,
  return converter.convert(quantizedEleType);
}

+// TODO(fengliuai): promote this utility method to mlir QuantOps.
+TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
+  auto factor_values = factor.dyn_cast_or_null<DenseFPElementsAttr>();
+  if (!factor_values) return {};
+  auto ele_type = quant::QuantizedType::getQuantizedElementType(input);
+  if (!ele_type) return {};
+  if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
+    ArrayRef<double> scales = qtype.getScales();
+    // Broadcasting hasn't been implemented yet.
+    if (scales.size() != factor_values.getNumElements()) return {};
+    SmallVector<double, 4> new_scales;
+    new_scales.reserve(scales.size());
+    auto scales_iter = scales.begin();
+    for (auto f : factor_values) {
+      new_scales.push_back(*(scales_iter++) *
+                           std::fabs(FloatAttr::getValueAsDouble(f)));
+    }
+    // We are assuming symmetric quantization.
+    auto new_ele_type = quant::UniformQuantizedPerAxisType::get(
+        qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
+        new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(),
+        qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
+    if (auto new_type = new_ele_type.castFromExpressedType(
+            quant::QuantizedType::castToExpressedType(input))) {
+      return TypeAttr::get(new_type);
+    }
+  }
+  // Currently, we only support per-axis quantized type.
+  return {};
+}
+
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                              Attribute max, int quant_dim,
                              IntegerAttr num_bits, BoolAttr narrow_range,
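For reference, a standalone sketch of the scale arithmetic RescaleQuantizedType performs, written against plain C++ containers rather than MLIR attributes. RescaleScales and its parameters are hypothetical stand-ins, not part of the patch:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Multiply each per-axis scale by the absolute value of the matching factor.
// An empty result signals a size mismatch, mirroring the "broadcasting hasn't
// been implemented" bail-out above; zero points stay untouched because
// symmetric quantization is assumed.
std::vector<double> RescaleScales(const std::vector<double> &scales,
                                  const std::vector<double> &factors) {
  if (scales.size() != factors.size()) return {};
  std::vector<double> new_scales;
  new_scales.reserve(scales.size());
  for (std::size_t i = 0; i < scales.size(); ++i) {
    new_scales.push_back(scales[i] * std::fabs(factors[i]));
  }
  return new_scales;
}

int main() {
  // Values from the fuseMulIntoConv2dWithQDQs test below: per-axis scales
  // {1, 2, 3} rescaled by a uniform factor of 1.5 become {1.5, 3.0, 4.5}.
  const auto rescaled = RescaleScales({1.0, 2.0, 3.0}, {1.5, 1.5, 1.5});
  for (double s : rescaled) std::printf("%.1f\n", s);
  return 0;
}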
@@ -352,6 +352,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
      return this->matchFailure();
    }

+    if (!new_qtype) return this->matchFailure();
    Type new_output_type = new_qtype.castFromExpressedType(
        QType::castToExpressedType(output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
@@ -360,6 +361,11 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
  }
};

+// Given a quantized type `input`, magnifying its scales by the factor stored in
+// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
+// dimension size of `input` or isn't floating-point, nullptr will be returned.
+TypeAttr RescaleQuantizedType(Type input, Attribute factor);
+
// Converts the min/max/num_bits/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and
@@ -143,6 +143,25 @@ func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1:
  // CHECK-SAME: fused_activation_function = "RELU6"
}

+// CHECK-LABEL: fuseMulIntoConv2dWithQDQs
+func @fuseMulIntoConv2dWithQDQs(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x3xf32> {
+  %cst = constant dense<1.5> : tensor<3xf32>
+  %cst_0 = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %w = constant dense<2.0> : tensor<3x3x3x3xf32>
+  %q = "tfl.quantize"(%w) {qtype = tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32:0,{1.0,2.0,3.0}>>} : (tensor<3x3x3x3xf32>) -> tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32:0,{1.0,2.0,3.0}>>
+  %dq = "tfl.dequantize"(%q) : (tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32:0,{1.0,2.0,3.0}>>) -> tensor<3x3x3x3xf32>
+  %0 = "tfl.conv_2d"(%arg0, %dq, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x3xf32>, tensor<3xf32>) -> tensor<256x30x30x3xf32>
+  %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x3xf32>, tensor<3xf32>) -> tensor<256x30x30x3xf32>
+  return %1 : tensor<256x30x30x3xf32>
+
+  // CHECK: %[[w:.*]] = constant dense<3.000000e+00> : tensor<3x3x3x3xf32>
+  // CHECK: %[[cst:.*]] = constant dense<[1.500000e+00, 3.000000e+00, 4.500000e+00]> : tensor<3xf32>
+  // CHECK: %[[q:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<3x3x3x3x!quant.uniform<i8<-127:127>:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>}
+  // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
+  // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]], %[[cst]])
+  // CHECK: return %[[conv]] : tensor<256x30x30x3xf32>
+}
+
// CHECK-LABEL: @fuseMulIntoFullyConnected
func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
@@ -40,6 +40,7 @@ limitations under the License.
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -171,6 +172,10 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
  return ExpandTo4DForConvImpl(a, true);
}

+TypeAttr RescaleQtype(Type input, Attribute factor) {
+  return TFL::RescaleQuantizedType(input, factor);
+}
+
// Returns shape of a ranked tensor.
// Precondition: output_val's is ranked tensor.
DenseElementsAttr GetShape(Value output_val) {
@@ -287,7 +292,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
      return matchFailure();
-    if (fc_op.fused_activation_function().equals("None")) return matchFailure();
+    if (fc_op.fused_activation_function() != "NONE") return matchFailure();

    // Broadcast the constant operand of Mul if it isn't compatible to the
    // filter input. We only support broadcasting the operand along the depth
@@ -334,8 +339,110 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
  }
};

+// Fuse Mul with proceeding Affine ops. This is an C++ implementation of the
+// following table gen implementation, which doesn't derived the result type of
+// the TFL_DequantizeOp.
+// def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
+//                       (TFL_DequantizeOp (TFL_QuantizeOp
+//                           (ConstantOp F32ElementsAttr:$filter), $qtype)),
+//                       (ConstantOp F32ElementsAttr:$bias),
+//                       $h_factor, $w_factor, TFL_AF_None,
+//                       $padding, $stride_h, $stride_w),
+//                      (ConstantOp F32ElementsAttr:$value), $act_fn),
+//           (TFL_Conv2DOp $input,
+//                      (TFL_DequantizeOp (TFL_QuantizeOp
+//                          (TFL_MulOp (ConstantOp $filter),
+//                                     (ConstantOp (ExpandTo4DForConv $value)),
+//                                     TFL_AF_None),
+//                          (RescaleQtype $qtype, $value))),
+//                      (TFL_MulOp (ConstantOp $bias), (ConstantOp $value),
+//                          TFL_AF_None),
+//                      $h_factor, $w_factor, $act_fn,
+//                      $padding, $stride_h, $stride_w),
+//         [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
+//          (HasOneUse $conv_output),
+//          (IsPerAxisQuantization $qtype), // per-axis quantization
+//         ]>;
+template <typename AffineOpType>
+struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
+  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TFL::MulOp mul_op,
+                                     PatternRewriter &rewriter) const override {
+    // Mul. Required 1-D rhs for batch normalization.
+    DenseElementsAttr gamma_cst;
+    Value gamma = mul_op.rhs();
+    if (!matchPattern(gamma, m_Constant(&gamma_cst))) return matchFailure();
+    if (gamma_cst.getType().getRank() != 1) return matchFailure();
+
+    // Affine op
+    Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
+    auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
+    if (!fc_op) return matchFailure();
+    Value filter = fc_op.filter();
+    Value bias = fc_op.bias();
+
+    // QDQs
+    auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
+    if (!dq_op) return matchFailure();
+    auto q_op =
+        dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
+    if (!q_op) return matchFailure();
+    filter = q_op.input();
+
+    // weight constant
+    ElementsAttr cst_tmp;
+    if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
+    if (!bias.getType().isa<NoneType>() &&
+        !matchPattern(bias, m_Constant(&cst_tmp)))
+      return matchFailure();
+    if (fc_op.fused_activation_function() != "NONE") return matchFailure();
+
+    // Broadcast the constant operand of Mul if it isn't compatible to the
+    // filter input. We only support broadcasting the operand along the depth
+    // dimension, when the operand's depth is 1.
+    rewriter.setInsertionPoint(q_op);
+    Location loc = fc_op.getLoc();
+    Value broadcasted_gamma;
+    if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
+      auto mul_rhs = ExpandTo4DForConv(gamma_cst);
+      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
+    } else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
+      auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
+      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
+    } else {
+      return matchFailure();
+    }
+
+    // Rewrite filter constant. Since the folder of TFL::MulOp couldn't
+    // broadcast the operands, TF::MulOp is used to fold the constant.
+    auto new_filter =
+        rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
+    // Update the scale in the quantize op.
+    auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
+    if (!new_qtype) return matchFailure();
+    rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
+                                                 new_filter, new_qtype);
+
+    // If bias isn't None, it needs to be multiplied as well.
+    if (!bias.getType().isa<NoneType>()) {
+      rewriter.setInsertionPoint(fc_op);
+      auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
+      fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
+    }
+
+    // Remove the tailing mul op.
+    mul_op.replaceAllUsesWith(fc_op.getResult());
+    return matchSuccess();
+  }
+};
+
+using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
+using FuseDepthwiseConv2DAndMulWithQDQs =
+    FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;
+
// Fuse Binary Op with following Affine operation.
-template <typename ConcreteType, typename AffineOpType>
+template <typename AffineOpType>
struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
  using OpRewritePattern<AffineOpType>::OpRewritePattern;

@@ -469,37 +576,11 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
  }
};

-class FuseBinaryOpToFollowingFullyConnected
-    : public FuseBinaryOpToFollowingAffineOp<
-          FuseBinaryOpToFollowingFullyConnected, FullyConnectedOp> {
- public:
-  using BaseType =
-      FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingFullyConnected,
-                                      FullyConnectedOp>;
-  explicit FuseBinaryOpToFollowingFullyConnected(MLIRContext *context)
-      : BaseType(context) {}
-};
-
-class FuseBinaryOpToFollowingDepthwiseConv2D
-    : public FuseBinaryOpToFollowingAffineOp<
-          FuseBinaryOpToFollowingDepthwiseConv2D, DepthwiseConv2DOp> {
- public:
-  using BaseType =
-      FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingDepthwiseConv2D,
-                                      DepthwiseConv2DOp>;
-  explicit FuseBinaryOpToFollowingDepthwiseConv2D(MLIRContext *context)
-      : BaseType(context) {}
-};
-
-class FuseBinaryOpToFollowingConv2D
-    : public FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingConv2D,
-                                             Conv2DOp> {
- public:
-  using BaseType =
-      FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingConv2D, Conv2DOp>;
-  explicit FuseBinaryOpToFollowingConv2D(MLIRContext *context)
-      : BaseType(context) {}
-};
+using FuseBinaryOpToFollowingFullyConnected =
+    FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
+using FuseBinaryOpToFollowingDepthwiseConv2D =
+    FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
+using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;

void Optimize::runOnFunction() {
  OwningRewritePatternList patterns;
@@ -517,7 +598,9 @@ void Optimize::runOnFunction() {
  // Fuse the binary ops with the following ops.
  patterns.insert<FuseBinaryOpToFollowingConv2D,
                  FuseBinaryOpToFollowingDepthwiseConv2D,
-                  FuseBinaryOpToFollowingFullyConnected>(ctx);
+                  FuseBinaryOpToFollowingFullyConnected,
+                  FuseConv2DAndMulWithQDQs, FuseDepthwiseConv2DAndMulWithQDQs>(
+      ctx);
  applyPatternsGreedily(func, patterns);
}