Add a constraint for fusing Add/Sub into Conv2D/DepthwiseConv2D, and make sure the operand shape can be fused with the bias.
PiperOrigin-RevId: 290786730
Change-Id: I593294c7fee147ec2d8abd6a9f4a757540f1acc8
parent 4ce69d9a0f
commit 182520682f
All changes are under tensorflow/compiler/mlir/lite.
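At its core, the new constraint is a pure shape predicate. Below is a minimal standalone sketch of that rule, as a hypothetical mirror of CanFuseConvOrDepthwiseConvShapes using std::vector instead of llvm::ArrayRef; the middle of the real function is not shown in this diff, so the channel-matching step is inferred from the function's leading comment and the test updates, not copied from the source.

#include <cstdint>
#include <vector>

// Hypothetical, dependency-free mirror of the new shape predicate.
bool CanFuseShapes(const std::vector<int64_t>& filter_shape,
                   const std::vector<int64_t>& elements_shape,
                   bool is_depthwise) {
  // The add/sub operand must be a scalar, rank 1, or rank 4 with all
  // leading dimensions equal to 1 (only the channel dimension may differ).
  const size_t rank = elements_shape.size();
  if (rank != 0 && rank != 1 && rank != 4) return false;
  for (size_t i = 0; i + 1 < rank; ++i) {
    if (elements_shape[i] != 1) return false;
  }
  // A single-element operand broadcasts against any bias.
  const int64_t depth = (rank == 0) ? 1 : elements_shape.back();
  if (depth == 1) return true;
  // TFLite filters are OHWI for Conv2D and 1HWO for DepthwiseConv2D, so the
  // output-channel count is the first or the last filter dimension.
  if (filter_shape.empty()) return false;
  const int64_t out_channels =
      is_depthwise ? filter_shape.back() : filter_shape.front();
  return depth == out_channels;
}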
MLIR tests:

@@ -78,10 +78,10 @@ func @fuseSubIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x3
 }
 
 // CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
-func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
+func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
   %cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
   %cst_0 = constant dense<1.5> : tensor<16xf32>
-  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>

@@ -90,10 +90,10 @@ func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
 }
 
 // CHECK-LABEL: fuseSubIntoDepthwiseConv2d
-func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
+func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
   %cst = constant dense<0.5> : tensor<16xf32>
   %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
-  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>

@@ -131,10 +131,10 @@ func @fuseAddWithRelu6IntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
 }
 
 // CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d
-func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
+func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
   %cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
   %cst_0 = constant dense<1.5> : tensor<16xf32>
-  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   %1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>
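These depthwise tests previously used a tensor<16x3x3x3xf32> filter, whose last dimension (3) does not match the 16-channel add/sub operand; under the new constraint the pattern would no longer fire. Rewriting the filter as tensor<3x3x3x16xf32> matches TFLite's 1HWO depthwise layout, with output channels last. Exercising the hypothetical sketch above with both shapes:

#include <cassert>

int main() {
  // Old filter shape: channel count in front, so the trailing dim is 3.
  assert(!CanFuseShapes({16, 3, 3, 3}, {16}, /*is_depthwise=*/true));
  // Updated filter shape: output channels last, matching tensor<16xf32>.
  assert(CanFuseShapes({3, 3, 3, 16}, {16}, /*is_depthwise=*/true));
}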
Optimize pass (C++):

@@ -93,19 +93,13 @@ bool IsTailOfShape(Type type1, Type type2) {
   return std::equal(i1, e1, i2);
 }
 
-bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
-                                bool is_depthwise) {
+bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
+                                      const ArrayRef<int64_t> elements_shape,
+                                      bool is_depthwise) {
   // Make sure the val tensor has shape where all dimensions are 1 except
   // last one.
   // Also, val tensor must be of rank 1 or 4 or 0 (scalar).
-  const auto elements = val.dyn_cast<DenseElementsAttr>();
-  const auto elements_shape = elements.getType().getShape();
-  const auto filter_elements = filter.dyn_cast<DenseElementsAttr>();
-  const auto filter_shape = filter_elements.getType().getShape();
-  const auto elements_rank = elements.getType().getRank();
-  if (!elements || !filter_elements) {
-    return false;
-  }
+  const auto elements_rank = elements_shape.size();
   for (int i = 0; i < static_cast<int>(elements_shape.size()) - 1; ++i) {
     if (elements_shape[i] != 1) return false;
   }
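Besides making the shape logic reusable, this refactor removes a hazard visible in the deleted lines: elements.getType() and filter_elements.getType() were called before the !elements || !filter_elements guard, so a non-DenseElementsAttr operand would dereference a null attribute. The new wrappers below validate the dyn_cast results before touching any shapes.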
@@ -125,6 +119,30 @@ bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
   return true;
 }
 
+bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
+                                bool is_depthwise) {
+  const auto elements = val.dyn_cast<DenseElementsAttr>();
+  if (!elements) {
+    return false;
+  }
+  const auto elements_shape = elements.getType().getShape();
+  const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
+  return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
+                                          is_depthwise);
+}
+
+bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
+                                bool is_depthwise) {
+  if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
+    if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
+      return CanFuseConvOrDepthwiseConvShapes(
+          filter_elements.getType().getShape(), elements.getType().getShape(),
+          is_depthwise);
+    }
+  }
+  return false;
+}
+
 // Expand Attribute 'a' to 4D with all 1s except 1 dimension.
 // Which dimension depends on 'is_depthwise' is true or false.
 ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
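The shared predicate now has two entry points: a Value overload for the declarative rewrite patterns, where $filter is bound to the convolution's filter operand and only its ShapedType is needed (the filter does not have to be a constant for its shape to be read), and an Attribute overload for callers that already hold constant filter data, which returns false when either attribute is not a DenseElementsAttr.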
TableGen rewrite patterns:

@@ -54,6 +54,10 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                      [TFL_Relu1Op, TFL_AF_Relu1]] in
   defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
 
+class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
+  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
+
 // Checks if the value has only one user.
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

@@ -72,7 +76,8 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
                                     (ConstantOp $value), TFL_AF_None),
                           $h_factor, $w_factor, $act_fn,
                           $padding, $stride_h, $stride_w),
-          [(HasOneUse $output)]>;
+          [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
+           (HasOneUse $output)]>;
   def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                           (ConstantOp F32ElementsAttr:$bias),
                           $h_factor, $w_factor, TFL_AF_None,

@@ -86,14 +91,12 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
                           $h_factor, $w_factor, $act_fn,
                           $padding, $stride_h, $stride_w,
                           $multiplier),
-          [(HasOneUse $output)]>;
+          [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
+           (HasOneUse $output)]>;
 }
 foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
   defm : FuseBinaryOpToPrecedingAffine<binaryOp>;
 
-class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
-  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
-
 def ExpandTo4DForConv: NativeCodeCall<"ExpandTo4DForConv($0)">;
 
 def ExpandTo4DForDepthwiseConv: NativeCodeCall<
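The Constraint class also moves above its first use in FuseBinaryOpToPrecedingAffine. In the CPred string, $0 and $1 are substituted with the constraint's arguments ($filter and $value), and # is TableGen's string-concatenation operator, so CanFuseConvOrDepthwiseConv<"true"> makes the generated matcher call TFL::CanFuseConvOrDepthwiseConv(filter, value, true) and abandon the rewrite when the check fails.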