Add a constraint for fusing Add/Sub into Conv2D/DepthwiseConv2D and make sure the operand shape can be fused with the bias.

PiperOrigin-RevId: 290786730
Change-Id: I593294c7fee147ec2d8abd6a9f4a757540f1acc8
Author: Karim Nosir, 2020-01-21 11:31:20 -08:00; committed by TensorFlower Gardener
Parent: 4ce69d9a0f
Commit: 182520682f
3 changed files with 42 additions and 21 deletions
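The constraint reduces to a pure shape test: the Add/Sub operand must broadcast cleanly onto the convolution's per-output-channel bias. Below is a minimal standalone sketch of that test; the helper name CanFuseShapes is hypothetical, and the channel-matching details are assumptions inferred from the TFLite filter layouts (the committed version is CanFuseConvOrDepthwiseConvShapes in the C++ diff further down).

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch of the shape constraint (assumed semantics): can a constant with
    // operand_shape fold into the bias of a convolution whose filter has
    // filter_shape? TFLite stores Conv2D filters as OHWI and DepthwiseConv2D
    // filters as 1HWO, so the output-channel count is the first or the last
    // filter dimension, respectively.
    bool CanFuseShapes(const std::vector<int64_t>& filter_shape,
                       const std::vector<int64_t>& operand_shape,
                       bool is_depthwise) {
      // The operand must be a scalar or of rank 1 or 4.
      const std::size_t rank = operand_shape.size();
      if (rank != 0 && rank != 1 && rank != 4) return false;
      // Every dimension except the last must be 1, e.g. tensor<1x1x1x16xf32>.
      for (std::size_t i = 0; i + 1 < rank; ++i)
        if (operand_shape[i] != 1) return false;
      if (rank == 0) return true;  // A scalar broadcasts onto any bias.
      const int64_t out_channels =
          is_depthwise ? filter_shape.back() : filter_shape.front();
      // The trailing dimension must match the bias length (or be 1).
      return operand_shape.back() == 1 || operand_shape.back() == out_channels;
    }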
tensorflow/compiler/mlir/lite


@@ -78,10 +78,10 @@ func @fuseSubIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x3
 }
 
 // CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
-func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
+func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
   %cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
   %cst_0 = constant dense<1.5> : tensor<16xf32>
-  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>
@@ -90,10 +90,10 @@ func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
 }
 
 // CHECK-LABEL: fuseSubIntoDepthwiseConv2d
-func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
+func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
   %cst = constant dense<0.5> : tensor<16xf32>
   %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
-  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>
@@ -131,10 +131,10 @@ func @fuseAddWithRelu6IntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
 }
 
 // CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d
-func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
+func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
   %cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
   %cst_0 = constant dense<1.5> : tensor<16xf32>
-  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   %1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   return %1 : tensor<256x30x30x16xf32>
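This also explains the filter-shape updates in the depthwise tests above: with the 1HWO depthwise layout, the 16 output channels must sit in the last filter dimension for a tensor<16xf32> operand to fuse, hence tensor<16x3x3x3xf32> becomes tensor<3x3x3x16xf32>. Exercising the sketch from above (same assumptions, reusing the hypothetical CanFuseShapes):

    #include <cassert>

    int main() {
      // Old depthwise test filter <16x3x3x3>: last dim is 3, not 16, so a
      // 16-element operand would now be rejected for fusion.
      assert(!CanFuseShapes({16, 3, 3, 3}, {16}, /*is_depthwise=*/true));
      // Updated filter <3x3x3x16>: last dim matches the 16-element operand.
      assert(CanFuseShapes({3, 3, 3, 16}, {16}, /*is_depthwise=*/true));
      return 0;
    }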


@@ -93,19 +93,13 @@ bool IsTailOfShape(Type type1, Type type2) {
   return std::equal(i1, e1, i2);
 }
 
-bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
-                                bool is_depthwise) {
+bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
+                                      const ArrayRef<int64_t> elements_shape,
+                                      bool is_depthwise) {
   // Make sure the val tensor has shape where all dimensions are 1 except
   // last one.
   // Also, val tensor must be of rank 1 or 4 or 0 (scalar).
-  const auto elements = val.dyn_cast<DenseElementsAttr>();
-  const auto elements_shape = elements.getType().getShape();
-  const auto filter_elements = filter.dyn_cast<DenseElementsAttr>();
-  const auto filter_shape = filter_elements.getType().getShape();
-  const auto elements_rank = elements.getType().getRank();
-  if (!elements || !filter_elements) {
-    return false;
-  }
+  const auto elements_rank = elements_shape.size();
   for (int i = 0; i < static_cast<int>(elements_shape.size()) - 1; ++i) {
     if (elements_shape[i] != 1) return false;
   }
@@ -125,6 +119,30 @@ bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
   return true;
 }
 
+bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
+                                bool is_depthwise) {
+  const auto elements = val.dyn_cast<DenseElementsAttr>();
+  if (!elements) {
+    return false;
+  }
+  const auto elements_shape = elements.getType().getShape();
+  const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
+  return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
+                                          is_depthwise);
+}
+
+bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
+                                bool is_depthwise) {
+  if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
+    if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
+      return CanFuseConvOrDepthwiseConvShapes(
+          filter_elements.getType().getShape(), elements.getType().getShape(),
+          is_depthwise);
+    }
+  }
+  return false;
+}
+
 // Expand Attribute 'a' to 4D with all 1s except 1 dimension.
 // Which dimension depends on 'is_depthwise' is true or false.
 ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
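Design note on the refactor above: the shape logic now lives in CanFuseConvOrDepthwiseConvShapes so it can be reached through two thin adapters, a new overload taking the filter as a Value (usable when the filter is a pattern-matched operand rather than a known constant, which appears to be the point of this commit) and the original Attribute overload. The Attribute overload now also null-checks both attributes before touching them; the old code called elements.getType() before the !elements test, a latent ordering bug this change removes.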


@@ -54,6 +54,10 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                      [TFL_Relu1Op, TFL_AF_Relu1]] in
   defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
 
+class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
+  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
+
 // Checks if the value has only one user.
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
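For reference, the CPred above expands into a call of the form TFL::CanFuseConvOrDepthwiseConv($0, $1, is_depthwise), where $0 binds to the pattern's $filter as an mlir::Value and $1 to the matched constant attribute; that is why the C++ change above adds the Value overload. The class is also moved here, ahead of FuseBinaryOpToPrecedingAffine, so it is defined before the patterns below reference it.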
@@ -72,7 +76,8 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
                           (ConstantOp $value), TFL_AF_None),
                        $h_factor, $w_factor, $act_fn,
                        $padding, $stride_h, $stride_w),
-           [(HasOneUse $output)]>;
+           [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
+            (HasOneUse $output)]>;
   def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                           (ConstantOp F32ElementsAttr:$bias),
                           $h_factor, $w_factor, TFL_AF_None,
@@ -86,14 +91,12 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
                        $h_factor, $w_factor, $act_fn,
                        $padding, $stride_h, $stride_w,
                        $multiplier),
-           [(HasOneUse $output)]>;
+           [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
+            (HasOneUse $output)]>;
 }
 
 foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
   defm : FuseBinaryOpToPrecedingAffine<binaryOp>;
 
-class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
-  CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
-
 def ExpandTo4DForConv: NativeCodeCall<"ExpandTo4DForConv($0)">;
 
 def ExpandTo4DForDepthwiseConv: NativeCodeCall<