[XLA/MLIR] Fix a bug in mhlo.convolution->tf.DepthwiseConv2dNative legalization.
Reference: For a depthwise convolution, the feature_group_count argument is set to the number of input channels, and the filter is reshaped from [filter_height, filter_width, in_channels, channel_multiplier] to [filter_height, filter_width, 1, in_channels * channel_multiplier]. For more details, see the tf.nn.depthwise_conv2d API docs and the XLA convolution semantics (https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution). PiperOrigin-RevId: 355502740 Change-Id: I3fe455bb00f20958d2608a38e851c9de9c186374
This commit is contained in:
parent
cc9a7a30cf
commit
8b542c4304
@ -1613,15 +1613,17 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>
|
||||
|
||||
// CHECK-LABEL: func @convert_depthwise_conv2d(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
|
||||
// CHECK: return %[[VAL_2]] : tensor<1x8x8x16xf32>
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<[3, 3, 207, 16]> : tensor<4xi64>
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST]]) : (tensor<3x3x1x3312xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
|
||||
// CHECK: return %[[VAL_3]] : tensor<1x8x8x16xf32>
|
||||
// CHECK: }
|
||||
func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
|
||||
func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
|
||||
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
|
||||
feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
|
||||
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
|
||||
(tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32>
|
||||
return %0 : tensor<1x8x8x16xf32>
|
||||
}
|
||||
|
||||
|
||||
@ -98,6 +98,11 @@ class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
|
||||
input_feature_dimension);
|
||||
int feature_group_count = conv_op.feature_group_count();
|
||||
|
||||
if (feature_group_count != 1 && feature_group_count != input_channels) {
|
||||
// Group convolution is not supported yet.
|
||||
return failure();
|
||||
}
|
||||
|
||||
const bool is_depthwise_conv = input_channels == feature_group_count;
|
||||
std::string padding;
|
||||
|
||||
@ -124,7 +129,7 @@ class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
|
||||
}
|
||||
|
||||
CreateConvOp(conv_op, strides, padding, dilation, is_depthwise_conv,
|
||||
rewriter);
|
||||
input_channels, rewriter);
|
||||
return success();
|
||||
};
|
||||
|
||||
@ -153,12 +158,26 @@ class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
|
||||
|
||||
void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
|
||||
StringRef padding, ArrayRef<int64_t> dilation,
|
||||
bool is_depthwise_conv,
|
||||
bool is_depthwise_conv, int input_channels,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// TODO(chhe): To support more data formats other than "NHWC".
|
||||
if (is_depthwise_conv) {
|
||||
// Reshapes filter format to [filter_height, filter_width, in_channels,
|
||||
// channel_multiplier] from HLO's [filter_height, filter_width, 1,
|
||||
// in_channels * channel_multiplier] format.
|
||||
auto filter_type = conv_op.rhs().getType().cast<ShapedType>();
|
||||
llvm::ArrayRef<int64_t> hlo_filter_shape = filter_type.getShape();
|
||||
llvm::SmallVector<int64_t, 4> tf_filter_shape(hlo_filter_shape.begin(),
|
||||
hlo_filter_shape.end());
|
||||
tf_filter_shape[2] = input_channels;
|
||||
tf_filter_shape[3] = hlo_filter_shape.back() / input_channels;
|
||||
auto reshaped_filter = rewriter.create<mhlo::ReshapeOp>(
|
||||
conv_op.rhs().getLoc(),
|
||||
RankedTensorType::get(tf_filter_shape, filter_type.getElementType()),
|
||||
conv_op.rhs());
|
||||
|
||||
rewriter.replaceOpWithNewOp<DepthwiseConv2dNativeOp>(
|
||||
conv_op, conv_op.getType(), conv_op.lhs(), conv_op.rhs(),
|
||||
conv_op, conv_op.getType(), conv_op.lhs(), reshaped_filter,
|
||||
rewriter.getI64ArrayAttr(strides),
|
||||
/*padding=*/rewriter.getStringAttr(padding),
|
||||
/*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user