[XLA/MLIR] Fix a bug in mhlo.convolution->tf.DepthwiseConv2dNative legalization.

Reference: For depthwise convolution the feature_group_count argument would be set to the input feature dimension, and the filter would be reshaped from [filter_height, filter_width, in_channels, channel_multiplier] to [filter_height, filter_width, 1, in_channels * channel_multiplier]. For more details, see tf.nn.depthwise_conv2d (https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution).
PiperOrigin-RevId: 355502740
Change-Id: I3fe455bb00f20958d2608a38e851c9de9c186374
This commit is contained in:
A. Unique TensorFlower 2021-02-03 16:01:46 -08:00 committed by TensorFlower Gardener
parent cc9a7a30cf
commit 8b542c4304
2 changed files with 29 additions and 8 deletions

View File

@@ -1613,15 +1613,17 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>
// Depthwise case: HLO supplies the filter in [filter_height, filter_width, 1,
// in_channels * channel_multiplier] layout (here 3x3x1x3312, with
// in_channels = 207), so the legalization must first reshape it to TF's
// [filter_height, filter_width, in_channels, channel_multiplier] layout
// (3x3x207x16) before emitting tf.DepthwiseConv2dNative.
// CHECK-LABEL: func @convert_depthwise_conv2d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
// CHECK: %[[CST:.*]] = constant dense<[3, 3, 207, 16]> : tensor<4xi64>
// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST]]) : (tensor<3x3x1x3312xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
// CHECK: %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
// CHECK: return %[[VAL_3]] : tensor<1x8x8x16xf32>
// CHECK: }
func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
    {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
    feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
       (tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32>
  return %0 : tensor<1x8x8x16xf32>
}

View File

@@ -98,6 +98,11 @@ class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
input_feature_dimension);
int feature_group_count = conv_op.feature_group_count();
if (feature_group_count != 1 && feature_group_count != input_channels) {
// Group convolution is not supported yet.
return failure();
}
const bool is_depthwise_conv = input_channels == feature_group_count;
std::string padding;
@@ -124,7 +129,7 @@ class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
}
CreateConvOp(conv_op, strides, padding, dilation, is_depthwise_conv,
rewriter);
input_channels, rewriter);
return success();
};
@@ -153,12 +158,26 @@ class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
StringRef padding, ArrayRef<int64_t> dilation,
bool is_depthwise_conv,
bool is_depthwise_conv, int input_channels,
ConversionPatternRewriter &rewriter) const {
// TODO(chhe): To support more data formats other than "NHWC".
if (is_depthwise_conv) {
// Reshapes filter format to [filter_height, filter_width, in_channels,
// channel_multiplier] from HLO's [filter_height, filter_width, 1,
// in_channels * channel_multiplier] format.
auto filter_type = conv_op.rhs().getType().cast<ShapedType>();
llvm::ArrayRef<int64_t> hlo_filter_shape = filter_type.getShape();
llvm::SmallVector<int64_t, 4> tf_filter_shape(hlo_filter_shape.begin(),
hlo_filter_shape.end());
tf_filter_shape[2] = input_channels;
tf_filter_shape[3] = hlo_filter_shape.back() / input_channels;
auto reshaped_filter = rewriter.create<mhlo::ReshapeOp>(
conv_op.rhs().getLoc(),
RankedTensorType::get(tf_filter_shape, filter_type.getElementType()),
conv_op.rhs());
rewriter.replaceOpWithNewOp<DepthwiseConv2dNativeOp>(
conv_op, conv_op.getType(), conv_op.lhs(), conv_op.rhs(),
conv_op, conv_op.getType(), conv_op.lhs(), reshaped_filter,
rewriter.getI64ArrayAttr(strides),
/*padding=*/rewriter.getStringAttr(padding),
/*explicit_paddings=*/rewriter.getI64ArrayAttr({}),