Legalize tf.DepthwiseConv2dNative -> xla_hlo.conv

Convert tf.DepthwiseConv2dNative -> xla_hlo.conv by setting feature_groups = input_features.

PiperOrigin-RevId: 305137357
Change-Id: I472d162472649d68c236a3f55ef99bc8bde9d9ed
This commit is contained in:
Ahmed S. Taei 2020-04-06 16:00:53 -07:00 committed by TensorFlower Gardener
parent 64f4a59d5e
commit 79985f62fc
2 changed files with 49 additions and 12 deletions

View File

@ -2903,6 +2903,22 @@ func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -
return %0 : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: depthwiseconv_simple
func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> {
// CHECK: %[[RESHAPED_FILTER:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
// CHECK: "xla_hlo.conv"(%arg0, %[[RESHAPED_FILTER]])
// CHECK: feature_group_count = 3
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
data_format = "NHWC",
device = "",
dilations = [1, 1, 1, 1],
explicit_paddings = [],
padding = "VALID",
strides = [1, 1, 1, 1]
} : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
return %0 : tensor<2x3x4x9xf32>
}
// CHECK-LABEL: conv_valid_padding
func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
// CHECK: "xla_hlo.conv"(%arg0, %arg1)

View File

@ -777,7 +777,7 @@ NamedAttribute GetConvDimensionNumbersAttr(
// the paddings attribute anyway requires multiple source op attributes and
// result op attributes. Defining it as declarative rewrite rule will introduce
// some duplication in the C++ helper methods.
template <typename OpT, int num_spatial_dims>
template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
class ConvertConv : public OpRewritePattern<OpT> {
public:
using OpRewritePattern<OpT>::OpRewritePattern;
@ -868,7 +868,10 @@ class ConvertConv : public OpRewritePattern<OpT> {
int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
// TensorFlow convolution op verifies that the number of input channels is
// divisible by the number of filter channels.
int64_t feature_group_count = input_channels / filter_channels;
// For depthwise convolution the feature_group_count argument would be set
// to the input feature dimension.
int64_t feature_group_count =
depthwise_conv ? input_channels : input_channels / filter_channels;
auto feature_group_count_attr = rewriter.getNamedAttr(
"feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
@ -881,6 +884,22 @@ class ConvertConv : public OpRewritePattern<OpT> {
"padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
SmallVector<Value, 2> operands(op.getOperands());
// Reshape the filter to {spatial_dims...., 1,in_channels *
// channel_multiplier}
if (depthwise_conv) {
auto filter_shape = filter_ty.getShape();
llvm::SmallVector<int64_t, 4> new_shape(filter_shape.size());
for (int i = 0; i < num_spatial_dims; ++i) {
new_shape[i] = filter_shape[i];
}
new_shape[num_spatial_dims] = 1;
new_shape[num_spatial_dims + 1] =
filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1];
operands[1] = rewriter.create<xla_hlo::ReshapeOp>(
op.getLoc(),
RankedTensorType::get(new_shape, filter_ty.getElementType()),
operands[1]);
}
NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
dimension_numbers_attr, feature_group_count_attr,
batch_group_count_attr, paddings_attr};
@ -891,7 +910,9 @@ class ConvertConv : public OpRewritePattern<OpT> {
};
using ConvertConv2D = ConvertConv<TF::Conv2DOp, /*num_spatial_dims=*/2>;
using ConvertDepthConv2D =
ConvertConv<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
/*depthwise_conv=*/true>;
// Converts BF16 FloorDiv op to have casting operators on either end as BF16
// division can result in strange behavior.
//
@ -3767,15 +3788,15 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
patterns.insert<
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2D,
ConvertConv2DBackpropFilterOp, ConvertConv2DBackpropInputOp,
ConvertCumsumOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp,
ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp,
ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertDepthConv2D, ConvertConv2DBackpropFilterOp,
ConvertConv2DBackpropInputOp, ConvertCumsumOp, ConvertEinsumOp,
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, ConvertMaxOp,
ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp, ConvertMaxPoolGradOp,
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp,
ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,