Legalize tf.DepthwiseConv2dNative -> xla_hlo.conv
Convert tf.DepthwiseConv2dNative -> xla_hlo.conv by setting feature_group_count to the input feature dimension.

PiperOrigin-RevId: 305137357
Change-Id: I472d162472649d68c236a3f55ef99bc8bde9d9ed
Parent: 64f4a59d5e
Commit: 79985f62fc
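As background (not part of the change; every name and value below is illustrative), the lowering expresses a depthwise convolution as a grouped xla_hlo.conv: feature_group_count is set to the number of input channels, and the [height, width, in_channels, channel_multiplier] filter is reshaped so its last dimension becomes in_channels * channel_multiplier. A minimal standalone C++ sketch of that shape arithmetic, using the shapes from the new test below:

// Standalone illustration (not from the patch): shape arithmetic behind the
// depthwise -> grouped convolution mapping, using the depthwiseconv_simple
// test shapes. All names here are hypothetical.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // NHWC input and [H, W, in_channels, channel_multiplier] depthwise filter.
  std::vector<int64_t> input = {2, 4, 5, 3};   // tensor<2x4x5x3xf32>
  std::vector<int64_t> filter = {2, 2, 3, 3};  // tensor<2x2x3x3xf32>

  int64_t in_channels = input[3];
  int64_t channel_multiplier = filter[3];

  // Depthwise conv == grouped conv with one group per input channel.
  int64_t feature_group_count = in_channels;  // 3, matches the CHECK line

  // Filter is reshaped to {spatial_dims..., 1, in_channels * multiplier}.
  std::vector<int64_t> reshaped_filter = {filter[0], filter[1], 1,
                                          in_channels * channel_multiplier};

  // VALID padding, stride 1: out = in - filter + 1 per spatial dim.
  std::vector<int64_t> output = {input[0], input[1] - filter[0] + 1,
                                 input[2] - filter[1] + 1,
                                 in_channels * channel_multiplier};

  assert(feature_group_count == 3);
  assert((reshaped_filter == std::vector<int64_t>{2, 2, 1, 9}));
  assert((output == std::vector<int64_t>{2, 3, 4, 9}));
  std::cout << "feature_group_count = " << feature_group_count << "\n";
  return 0;
}

For the test case this gives feature_group_count = 3, a reshaped filter of 2x2x1x9, and an output of 2x3x4x9, which is what the CHECK lines in the new test expect.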
@@ -2903,6 +2903,22 @@ func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -
   return %0 : tensor<256x30x30x16xf32>
 }
 
+// CHECK-LABEL: depthwiseconv_simple
+func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> {
+  // CHECK: %[[RESHAPED_FILTER:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
+  // CHECK: "xla_hlo.conv"(%arg0, %[[RESHAPED_FILTER]])
+  // CHECK: feature_group_count = 3
+  %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
+    data_format = "NHWC",
+    device = "",
+    dilations = [1, 1, 1, 1],
+    explicit_paddings = [],
+    padding = "VALID",
+    strides = [1, 1, 1, 1]
+  } : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
+  return %0 : tensor<2x3x4x9xf32>
+}
+
 // CHECK-LABEL: conv_valid_padding
 func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
   // CHECK: "xla_hlo.conv"(%arg0, %arg1)
@@ -777,7 +777,7 @@ NamedAttribute GetConvDimensionNumbersAttr(
 // the paddings attribute anyway requires multiple source op attributes and
 // result op attributes. Defining it as declarative rewrite rule will introduce
 // some duplication in the C++ helper methods.
-template <typename OpT, int num_spatial_dims>
+template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
 class ConvertConv : public OpRewritePattern<OpT> {
  public:
   using OpRewritePattern<OpT>::OpRewritePattern;
@@ -868,7 +868,10 @@ class ConvertConv : public OpRewritePattern<OpT> {
     int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
     // TensorFlow convolution op verifies that the number of input channels is
     // divisible by the number of filter channels.
-    int64_t feature_group_count = input_channels / filter_channels;
+    // For depthwise convolution the feature_group_count argument would be set
+    // to the input feature dimension.
+    int64_t feature_group_count =
+        depthwise_conv ? input_channels : input_channels / filter_channels;
     auto feature_group_count_attr = rewriter.getNamedAttr(
         "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
 
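One detail worth spelling out (my reading of the change, not stated in the commit message): tf.DepthwiseConv2dNative filters are laid out as [height, width, in_channels, channel_multiplier], so GetDimSize(filter_ty, num_spatial_dims) returns the input channel count and the generic input_channels / filter_channels formula would always yield 1 for this op; the depthwise branch therefore uses the input feature dimension directly. A tiny hypothetical sketch with the test's values:

// Illustration only (hypothetical values mirroring the test shapes above):
// why the generic formula does not work for the depthwise case.
#include <cstdint>
#include <initializer_list>
#include <iostream>

int main() {
  const int64_t input_channels = 3;   // feature dim of tensor<2x4x5x3xf32>
  const int64_t filter_channels = 3;  // dim num_spatial_dims of tensor<2x2x3x3xf32>

  for (bool depthwise_conv : {false, true}) {
    int64_t feature_group_count =
        depthwise_conv ? input_channels : input_channels / filter_channels;
    std::cout << (depthwise_conv ? "depthwise: " : "regular:   ")
              << feature_group_count << "\n";  // prints 1, then 3
  }
  return 0;
}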
@@ -881,6 +884,22 @@ class ConvertConv : public OpRewritePattern<OpT> {
         "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
 
+    SmallVector<Value, 2> operands(op.getOperands());
+    // Reshape the filter to {spatial_dims...., 1,in_channels *
+    // channel_multiplier}
+    if (depthwise_conv) {
+      auto filter_shape = filter_ty.getShape();
+      llvm::SmallVector<int64_t, 4> new_shape(filter_shape.size());
+      for (int i = 0; i < num_spatial_dims; ++i) {
+        new_shape[i] = filter_shape[i];
+      }
+      new_shape[num_spatial_dims] = 1;
+      new_shape[num_spatial_dims + 1] =
+          filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1];
+      operands[1] = rewriter.create<xla_hlo::ReshapeOp>(
+          op.getLoc(),
+          RankedTensorType::get(new_shape, filter_ty.getElementType()),
+          operands[1]);
+    }
     NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
                               dimension_numbers_attr, feature_group_count_attr,
                               batch_group_count_attr, paddings_attr};
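For illustration (a standalone hypothetical sketch, not part of the patch), the same new_shape computation applied to the tensor<2x2x3x3xf32> filter from the test collapses the channel dimension to 1 and folds the multiplier into the output-feature dimension:

// Sketch (not from the patch): the new_shape computation above, applied to
// the tensor<2x2x3x3xf32> filter from the test. num_spatial_dims is 2 for
// Conv2D-style ops.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int num_spatial_dims = 2;
  std::vector<int64_t> filter_shape = {2, 2, 3, 3};  // [H, W, in, multiplier]

  std::vector<int64_t> new_shape(filter_shape.size());
  for (int i = 0; i < num_spatial_dims; ++i) new_shape[i] = filter_shape[i];
  new_shape[num_spatial_dims] = 1;
  new_shape[num_spatial_dims + 1] =
      filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1];

  for (int64_t d : new_shape) std::cout << d << " ";  // 2 2 1 9
  std::cout << "\n";
  return 0;
}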
@@ -891,7 +910,9 @@ class ConvertConv : public OpRewritePattern<OpT> {
 };
 
 using ConvertConv2D = ConvertConv<TF::Conv2DOp, /*num_spatial_dims=*/2>;
-
+using ConvertDepthConv2D =
+    ConvertConv<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
+                /*depthwise_conv=*/true>;
 // Converts BF16 FloorDiv op to have casting operators on either end as BF16
 // division can result in strange behavior.
 //
@@ -3767,15 +3788,15 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
   patterns.insert<
       ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
       ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2D,
-      ConvertConv2DBackpropFilterOp, ConvertConv2DBackpropInputOp,
-      ConvertCumsumOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp,
-      ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
-      ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
-      ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp,
-      ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
-      ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
-      ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
-      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
+      ConvertDepthConv2D, ConvertConv2DBackpropFilterOp,
+      ConvertConv2DBackpropInputOp, ConvertCumsumOp, ConvertEinsumOp,
+      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
+      ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
+      ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, ConvertMaxOp,
+      ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp, ConvertMaxPoolGradOp,
+      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
+      ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp,
+      ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
       ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,