Fuse binary ops into the following affine ops.

The binary ops include add/sub/mul/div, and the RHS must be a scalar constant.
The affine ops include fully_connected, conv2d and depthwise_conv2d.

When the bias is updated by the transformation, the bias dimension of the filter tensor differs per op: it is the first dimension for conv2d and fully_connected, and the last one for depthwise_conv2d.
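
Concretely, for a filter w, bias b, and scalar constant c, the rewrites are

  add/sub:  w * (x + c) + b  =>  w * x + (w * c + b)   (fold c into the bias)
  mul/div:  w * (x * c) + b  =>  (w * c) * x + b       (fold c into the filter)

with the sign of c flipped for sub, and the filter divided by c for div.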

PiperOrigin-RevId: 274722860
Author: Feng Liu (2019-10-14 21:03:06 -07:00), committed by TensorFlower Gardener
Commit: ce54a5932d, parent: c24e01b57f
2 changed files with 282 additions and 1 deletion

@@ -54,6 +54,36 @@ func @fuseSubIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
}
// CHECK-LABEL: fuseAddIntoFollowingConv2d
func @fuseAddIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<1.5> : tensor<f32>
%0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<256x32x32x3xf32>, tensor<f32>) -> tensor<256x32x32x3xf32>
%w = constant dense<1.0> : tensor<16x3x3x3xf32>
%bias = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%1 = "tfl.conv_2d"(%0, %w, %bias) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32>
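// With an all-ones 16x3x3x3 filter, each output channel sums 3 * 3 * 3 = 27
// filter values, so fusing the add shifts every bias entry by 27 * 1.5 = 40.5.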
// CHECK-NEXT: %[[w:.*]] = constant dense<1.000000e+00> : tensor<16x3x3x3xf32>
// CHECK-NEXT: %[[b:.*]] = constant dense<[4.150000e+01, 4.250000e+01, 4.350000e+01, 4.450000e+01, 4.550000e+01, 4.650000e+01, 4.750000e+01, 4.850000e+01, 4.950000e+01, 5.050000e+01, 5.150000e+01, 5.250000e+01, 5.350000e+01, 5.450000e+01, 5.550000e+01, 5.650000e+01]> : tensor<16xf32>
// CHECK-NEXT: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: return %[[c]] : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: fuseSubIntoFollowingConv2d
func @fuseSubIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<1.5> : tensor<f32>
%0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<256x32x32x3xf32>, tensor<f32>) -> tensor<256x32x32x3xf32>
%w = constant dense<1.0> : tensor<16x3x3x3xf32>
%bias = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%1 = "tfl.conv_2d"(%0, %w, %bias) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32>
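// Same arithmetic as the add case above, but with the sign flipped: every
// bias entry is shifted by -27 * 1.5 = -40.5.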
// CHECK-NEXT: %[[w:.*]] = constant dense<1.000000e+00> : tensor<16x3x3x3xf32>
// CHECK-NEXT: %[[b:.*]] = constant dense<[-3.950000e+01, -3.850000e+01, -3.750000e+01, -3.650000e+01, -3.550000e+01, -3.450000e+01, -3.350000e+01, -3.250000e+01, -3.150000e+01, -3.050000e+01, -2.950000e+01, -2.850000e+01, -2.750000e+01, -2.650000e+01, -2.550000e+01, -2.450000e+01]> : tensor<16xf32>
// CHECK-NEXT: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[w]], %[[b]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: return %[[c]] : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
@@ -78,6 +108,22 @@ func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst)
}
// CHECK-LABEL: fuseAddIntoFollowingDepthwiseConv2d
func @fuseAddIntoFollowingDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<1.5> : tensor<f32>
%0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<256x32x32x3xf32>, tensor<f32>) -> tensor<256x32x32x3xf32>
%w = constant dense<1.0> : tensor<3x3x3x16xf32>
%bias = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%1 = "tfl.depthwise_conv_2d"(%0, %w, %bias) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32>
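// For a depthwise filter of shape 3x3x3x16 the bias dimension is the last
// one; each of the 16 channels still sums 27 ones, so every bias entry is
// again shifted by 27 * 1.5 = 40.5.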
// CHECK-NEXT: %[[w:.*]] = constant dense<1.000000e+00> : tensor<3x3x3x16xf32>
// CHECK-NEXT: %[[b:.*]] = constant dense<[4.150000e+01, 4.250000e+01, 4.350000e+01, 4.450000e+01, 4.550000e+01, 4.650000e+01, 4.750000e+01, 4.850000e+01, 4.950000e+01, 5.050000e+01, 5.150000e+01, 5.250000e+01, 5.350000e+01, 5.450000e+01, 5.550000e+01, 5.650000e+01]> : tensor<16xf32>
// CHECK-NEXT: %[[dc:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[w]], %[[b]]) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: return %[[dc]] : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: fuseAddWithRelu6IntoConv2d
func @fuseAddWithRelu6IntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<1.5> : tensor<16xf32>
@@ -137,6 +183,56 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
// CHECK: return %[[RES]] : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseAddIntoFollowingFullyConnectedWithQDQs
func @fuseAddIntoFollowingFullyConnectedWithQDQs(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
%cst2 = constant dense<1.5> : tensor<f32>
%0 = "tfl.add"(%arg0, %cst2) {fused_activation_function = "NONE"} : (tensor<4x2xf32>, tensor<f32>) -> tensor<4x2xf32>
%cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
%q = "tfl.quantize"(%cst0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0>>
%dq = "tfl.dequantize"(%q) : (tensor<2x2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2x2xf32>
%cst1 = constant dense<2.0> : tensor<2xf32>
%1 = "tfl.fully_connected"(%0, %dq, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
return %1 : tensor<4x2xf32>
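// The add is folded through the quantize/dequantize pair into the bias:
// new_bias[i] = 2.0 + 1.5 * (sum of row i of the weight), giving
// 2.0 + 1.5 * 3.0 = 6.5 and 2.0 + 1.5 * 7.0 = 12.5.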
// CHECK-NEXT: %[[w:.*]] = constant dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[b:.*]] = constant dense<[6.500000e+00, 1.250000e+01]> : tensor<2xf32>
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[w]])
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq]], %[[b]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK-NEXT: return %[[fc]] : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseAddIntoFollowingFullyConnected
func @fuseAddIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
%cst2 = constant dense<1.5> : tensor<f32>
%0 = "tfl.add"(%arg0, %cst2) {fused_activation_function = "NONE"} : (tensor<4x2xf32>, tensor<f32>) -> tensor<4x2xf32>
%cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
%cst1 = constant dense<2.0> : tensor<2xf32>
%1 = "tfl.fully_connected"(%0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
return %1 : tensor<4x2xf32>
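// Same bias arithmetic as the QDQs case above: 2.0 + 1.5 * 3.0 = 6.5 and
// 2.0 + 1.5 * 7.0 = 12.5.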
// CHECK-NEXT: %[[w:.*]] = constant dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[b:.*]] = constant dense<[6.500000e+00, 1.250000e+01]> : tensor<2xf32>
// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK-NEXT: return %[[fc]] : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseMulIntoFollowingFullyConnected
func @fuseMulIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
%cst2 = constant dense<1.5> : tensor<f32>
%0 = "tfl.mul"(%arg0, %cst2) {fused_activation_function = "NONE"} : (tensor<4x2xf32>, tensor<f32>) -> tensor<4x2xf32>
%cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
%cst1 = constant dense<2.0> : tensor<2xf32>
%1 = "tfl.fully_connected"(%0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
return %1 : tensor<4x2xf32>
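// mul scales the weight instead of the bias: every weight entry is multiplied
// by 1.5, giving [[1.5, 3.0], [4.5, 6.0]], and the bias is left unchanged.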
// CHECK-NEXT: %[[b:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[w:.*]] = constant dense<{{\[}}[1.500000e+00, 3.000000e+00], [4.500000e+00, 6.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
// CHECK-NEXT: return %[[fc]] : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseMulIntoFullyConnectedBroadcast
func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> {
%cst0 = constant dense<[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32>

@@ -18,7 +18,11 @@ limitations under the License.
#include <climits>
#include <cstdint>
#include <functional>
#include <map>
#include <numeric>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -264,6 +268,185 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
}
};
// Fuse a binary op into the affine op that follows it.
template <typename ConcreteType, typename AffineOpType>
struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
using OpRewritePattern<AffineOpType>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineOpType fc_op,
PatternRewriter &rewriter) const override {
// Binary op.
Operation *binary_op = fc_op.input()->getDefiningOp();
if (!binary_op || binary_op->getNumOperands() != 2)
return this->matchFailure();
// We only handle the case where the RHS is a scalar constant.
// TODO(fengliuai): The canonicalizer pass currently can't guarantee that
// constant operands end up on the RHS, so consider handling an LHS constant
// operand if necessary.
DenseFPElementsAttr cst;
if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
return this->matchFailure();
if (cst.getNumElements() != 1) return this->matchFailure();
APFloat cst_value = *cst.float_value_begin();
// Affine op.
Value *filter = fc_op.filter();
Value *bias = fc_op.bias();
DenseFPElementsAttr filter_cst, bias_cst;
if (!matchPattern(filter, m_Constant(&filter_cst))) {
// The filter may be quantized; in that case, look through the
// dequantize/quantize pair to find the real constant.
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter->getDefiningOp());
if (!dq) return this->matchFailure();
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input()->getDefiningOp());
if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
return this->matchFailure();
}
filter = q.input();
}
if (!bias->getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&bias_cst)))
return this->matchFailure();
ShapedType filter_type = filter_cst.getType();
if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) {
auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
if (padding && padding.getValue() != "VALID") return this->matchFailure();
// The fusion of add/sub is actually applying the following
// transformation:
// w * (x + c) + b => w * x + (w * c + b)
// so we have to update the bias.
if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
auto bias_and_slice =
ConcreteType::GetBiasDimAndSliceSize(filter_type.getShape());
int64_t bias_size = bias_and_slice.first;
int64_t slice_size = bias_and_slice.second;
ShapedType new_bias_type =
rewriter.getTensorType({bias_size}, filter_type.getElementType());
// The new bias should be a 1-D tensor whose length equals the bias
// dimension of the weight.
SmallVector<APFloat, 4> new_bias_values;
if (bias->getType().isa<NoneType>()) { // none bias, a list of zeros
new_bias_values.resize(bias_size, APFloat(0.0));
} else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
} else if (bias_cst.getNumElements() == bias_size) { // 1-d bias, copy it
new_bias_values.insert(new_bias_values.begin(),
bias_cst.float_value_begin(),
bias_cst.float_value_end());
} else {
return this->matchFailure();
}
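// Accumulate w * c into the bias. The filter elements are visited in
// row-major order, so (flatten_index / slice_size) % bias_size recovers the
// output channel both when the bias dimension is first (conv2d and
// fully_connected, where slice_size is the product of the remaining
// dimensions) and when it is last (depthwise_conv_2d, where slice_size is 1).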
int64_t flatten_index = 0;
for (auto fp_it = filter_cst.float_value_begin(),
fp_end = filter_cst.float_value_end();
fp_it != fp_end; ++fp_it) {
int bias_index = (flatten_index++ / slice_size) % bias_size;
new_bias_values[bias_index] =
new_bias_values[bias_index] + *fp_it * cst_value;
}
auto new_bias = DenseFPElementsAttr::get(new_bias_type, new_bias_values);
auto new_bias_op =
rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
fc_op.setOperand(0, binary_op->getOperand(0));
fc_op.setOperand(2, new_bias_op);
} else if (llvm::isa<MulOp>(binary_op) || llvm::isa<DivOp>(binary_op)) {
// The fusion of mul/div is actually applying the following
// transformation, where ∘ stands for mul or div:
//     w * (x ∘ c) + b => (w ∘ c) * x + b
// so we have to update the weight.
bool is_mul = llvm::isa<MulOp>(binary_op);
auto new_filter =
filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
});
// We recreate the constant op in case it is shared by other ops. This
// might increase the model size.
auto new_filter_op = rewriter.create<ConstOp>(
fc_op.getLoc(), filter->getType(), new_filter);
fc_op.setOperand(0, binary_op->getOperand(0));
if (fc_op.filter() != filter) {
// This filter goes through quantize and dequantize ops. Then we just
// need to update the weight to the quantize op.
filter->replaceAllUsesWith(new_filter_op);
} else {
// This filter doesn't go through quantize and dequantize ops, so we
// update the weight of the affine op directly.
fc_op.setOperand(1, new_filter_op);
}
} else {
return this->matchFailure();
}
return this->matchSuccess();
}
};
class FuseBinaryOpToFollowingFullyConnected
: public FuseBinaryOpToFollowingAffineOp<
FuseBinaryOpToFollowingFullyConnected, FullyConnectedOp> {
public:
using BaseType =
FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingFullyConnected,
FullyConnectedOp>;
explicit FuseBinaryOpToFollowingFullyConnected(MLIRContext *context)
: BaseType(context) {}
// The first dimension of the fully-connected weight needs to match the last
// dimension of the op result and also the (broadcasted) size of the bias.
// The product of the remaining dimensions is the slice size.
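// For the 2x2 fully-connected filter in the tests above, this returns {2, 2}.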
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
ArrayRef<int64_t> filter_shape) {
int64_t depth =
std::accumulate(std::next(filter_shape.begin()), filter_shape.end(), 1,
std::multiplies<int64_t>());
return {filter_shape.front(), depth};
}
};
class FuseBinaryOpToFollowingDepthwiseConv2D
: public FuseBinaryOpToFollowingAffineOp<
FuseBinaryOpToFollowingDepthwiseConv2D, DepthwiseConv2DOp> {
public:
using BaseType =
FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingDepthwiseConv2D,
DepthwiseConv2DOp>;
explicit FuseBinaryOpToFollowingDepthwiseConv2D(MLIRContext *context)
: BaseType(context) {}
// The last dimension of the depthwise conv 2d weight needs to match the last
// dimension of the op result and also the (broadcasted) size of the bias.
// The slice size is just 1.
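// For the 3x3x3x16 depthwise filter in the tests above, this returns {16, 1}.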
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
ArrayRef<int64_t> filter_shape) {
return {filter_shape.back(), 1};
}
};
class FuseBinaryOpToFollowingConv2D
: public FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingConv2D,
Conv2DOp> {
public:
using BaseType =
FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingConv2D, Conv2DOp>;
explicit FuseBinaryOpToFollowingConv2D(MLIRContext *context)
: BaseType(context) {}
// The first dimension of the conv 2d weight needs to match the last
// dimension of the op result and also the (broadcasted) size of the bias.
// The product of the remaining dimensions is the slice size.
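// For the 16x3x3x3 conv filter in the tests above, this returns {16, 27}.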
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
ArrayRef<int64_t> filter_shape) {
int64_t depth =
std::accumulate(std::next(filter_shape.begin()), filter_shape.end(), 1,
std::multiplies<int64_t>());
return {filter_shape.front(), depth};
}
};
void Optimize::runOnFunction() {
OwningRewritePatternList patterns;
auto *ctx = &getContext();
@@ -272,7 +455,9 @@ void Optimize::runOnFunction() {
// Add the generated patterns to the list.
TFL::populateWithGenerated(ctx, &patterns);
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
-               FuseFullyConnectedAndMul>(ctx);
+               FuseFullyConnectedAndMul, FuseBinaryOpToFollowingConv2D,
+               FuseBinaryOpToFollowingDepthwiseConv2D,
+               FuseBinaryOpToFollowingFullyConnected>(ctx);
applyPatternsGreedily(func, patterns);
}