From 79985f62fc83169a19760b7131cb4f7c1306469c Mon Sep 17 00:00:00 2001
From: "Ahmed S. Taei"
Date: Mon, 6 Apr 2020 16:00:53 -0700
Subject: [PATCH] Legalize tf.DepthwiseConv2dNative -> xla_hlo.conv

Convert tf.DepthwiseConv2dNative -> xla_hlo.conv by setting
feature_group_count to the input feature dimension.

PiperOrigin-RevId: 305137357
Change-Id: I472d162472649d68c236a3f55ef99bc8bde9d9ed
---
 .../compiler/mlir/xla/tests/legalize-tf.mlir  | 16 +++++++
 .../mlir/xla/transforms/legalize_tf.cc        | 45 ++++++++++++++-----
 2 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index a462a7f4a1f..7c7f6f306cf 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -2903,6 +2903,22 @@ func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
   return %0 : tensor<256x30x30x16xf32>
 }
 
+// CHECK-LABEL: depthwiseconv_simple
+func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> {
+  // CHECK: %[[RESHAPED_FILTER:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
+  // CHECK: "xla_hlo.conv"(%arg0, %[[RESHAPED_FILTER]])
+  // CHECK: feature_group_count = 3
+  %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
+    data_format = "NHWC",
+    device = "",
+    dilations = [1, 1, 1, 1],
+    explicit_paddings = [],
+    padding = "VALID",
+    strides = [1, 1, 1, 1]
+  } : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
+  return %0 : tensor<2x3x4x9xf32>
+}
+
 // CHECK-LABEL: conv_valid_padding
 func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
   // CHECK: "xla_hlo.conv"(%arg0, %arg1)
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index f8ec40aa42d..aa6ac85b4af 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -777,7 +777,7 @@ NamedAttribute GetConvDimensionNumbersAttr(
 // the paddings attribute anyway requires multiple source op attributes and
 // result op attributes. Defining it as declarative rewrite rule will introduce
 // some duplication in the C++ helper methods.
-template <typename OpT, int num_spatial_dims>
+template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
 class ConvertConv : public OpRewritePattern<OpT> {
  public:
   using OpRewritePattern<OpT>::OpRewritePattern;
@@ -868,7 +868,10 @@ class ConvertConv : public OpRewritePattern<OpT> {
     int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
     // TensorFlow convolution op verifies that the number of input channels is
     // divisible by the number of filter channels.
-    int64_t feature_group_count = input_channels / filter_channels;
+    // For depthwise convolution the feature_group_count argument is set to the
+    // input feature dimension.
+    int64_t feature_group_count =
+        depthwise_conv ? input_channels : input_channels / filter_channels;
     auto feature_group_count_attr = rewriter.getNamedAttr(
         "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
 
@@ -881,6 +884,22 @@ class ConvertConv : public OpRewritePattern<OpT> {
         "padding", DenseElementsAttr::get(paddings_ty, paddings));
 
     SmallVector<Value, 2> operands(op.getOperands());
+    // Reshape the filter to {spatial_dims..., 1, in_channels *
+    // channel_multiplier}.
+    if (depthwise_conv) {
+      auto filter_shape = filter_ty.getShape();
+      llvm::SmallVector<int64_t, 4> new_shape(filter_shape.size());
+      for (int i = 0; i < num_spatial_dims; ++i) {
+        new_shape[i] = filter_shape[i];
+      }
+      new_shape[num_spatial_dims] = 1;
+      new_shape[num_spatial_dims + 1] =
+          filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1];
+      operands[1] = rewriter.create<xla_hlo::ReshapeOp>(
+          op.getLoc(),
+          RankedTensorType::get(new_shape, filter_ty.getElementType()),
+          operands[1]);
+    }
     NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
                               dimension_numbers_attr, feature_group_count_attr,
                               batch_group_count_attr, paddings_attr};
@@ -891,7 +910,9 @@ class ConvertConv : public OpRewritePattern<OpT> {
 };
 
 using ConvertConv2D = ConvertConv<TF::Conv2DOp, 2>;
-
+using ConvertDepthConv2D =
+    ConvertConv<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
+                /*depthwise_conv=*/true>;
 // Converts BF16 FloorDiv op to have casting operators on either end as BF16
 // division can result in strange behavior.
 //
@@ -3767,15 +3788,15 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
   patterns.insert<
       ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
       ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2D,
-      ConvertConv2DBackpropFilterOp, ConvertConv2DBackpropInputOp,
-      ConvertCumsumOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp,
-      ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
-      ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
-      ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp,
-      ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
-      ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
-      ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
-      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
+      ConvertDepthConv2D, ConvertConv2DBackpropFilterOp,
+      ConvertConv2DBackpropInputOp, ConvertCumsumOp, ConvertEinsumOp,
+      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
+      ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
+      ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, ConvertMaxOp,
+      ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp, ConvertMaxPoolGradOp,
+      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
+      ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp,
+      ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
       ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
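
For reference, the essence of the rewrite above: a TF depthwise convolution with filter shape [spatial..., in_channels, channel_multiplier] is expressed as a grouped xla_hlo.conv with feature_group_count = in_channels and the filter reshaped to [spatial..., 1, in_channels * channel_multiplier]; with one group per input channel, each input channel convolves only with its own channel_multiplier filters, which is exactly the depthwise semantics. The standalone C++ sketch below (not part of the patch) mirrors only that shape arithmetic so it can be checked in isolation; the helper name DepthwiseFilterShape and the use of std::vector are illustrative assumptions rather than code from the patch, which operates on MLIR types and llvm::SmallVector.

// Standalone sketch of the depthwise filter reshape and grouping logic.
// DepthwiseFilterShape is a hypothetical helper used only for illustration.
#include <cassert>
#include <cstdint>
#include <vector>

// Given a TF depthwise filter shape [spatial..., in_channels, channel_multiplier],
// returns the reshaped filter shape [spatial..., 1, in_channels * channel_multiplier]
// and sets feature_group_count to in_channels, mirroring the rewrite above.
std::vector<int64_t> DepthwiseFilterShape(const std::vector<int64_t>& filter_shape,
                                          int num_spatial_dims,
                                          int64_t* feature_group_count) {
  const int64_t in_channels = filter_shape[num_spatial_dims];
  const int64_t channel_multiplier = filter_shape[num_spatial_dims + 1];
  *feature_group_count = in_channels;  // One group per input channel.
  std::vector<int64_t> new_shape(filter_shape.begin(),
                                 filter_shape.begin() + num_spatial_dims);
  new_shape.push_back(1);
  new_shape.push_back(in_channels * channel_multiplier);
  return new_shape;
}

int main() {
  // Matches the depthwiseconv_simple test: filter 2x2x3x3 -> 2x2x1x9 with
  // feature_group_count = 3.
  int64_t groups = 0;
  const auto reshaped =
      DepthwiseFilterShape({2, 2, 3, 3}, /*num_spatial_dims=*/2, &groups);
  assert(groups == 3);
  assert((reshaped == std::vector<int64_t>{2, 2, 1, 9}));
  return 0;
}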