Legalize tf.DepthwiseConv2dNative -> xla_hlo.conv
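Convert tf.DepthwiseConv2dNative -> xla_hlo.conv by setting feature_group_count to the number of input features and reshaping the filter to {spatial_dims..., 1, in_channels * channel_multiplier}. PiperOrigin-RevId: 305137357 Change-Id: I472d162472649d68c236a3f55ef99bc8bde9d9ed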
This commit is contained in:
parent
64f4a59d5e
commit
79985f62fc
@ -2903,6 +2903,22 @@ func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -
|
|||||||
return %0 : tensor<256x30x30x16xf32>
|
return %0 : tensor<256x30x30x16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: depthwiseconv_simple
|
||||||
|
func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> {
|
||||||
|
// CHECK: %[[RESHAPED_FILTER:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
|
||||||
|
// CHECK: "xla_hlo.conv"(%arg0, %[[RESHAPED_FILTER]])
|
||||||
|
// CHECK: feature_group_count = 3
|
||||||
|
%0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
|
||||||
|
data_format = "NHWC",
|
||||||
|
device = "",
|
||||||
|
dilations = [1, 1, 1, 1],
|
||||||
|
explicit_paddings = [],
|
||||||
|
padding = "VALID",
|
||||||
|
strides = [1, 1, 1, 1]
|
||||||
|
} : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
|
||||||
|
return %0 : tensor<2x3x4x9xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: conv_valid_padding
|
// CHECK-LABEL: conv_valid_padding
|
||||||
func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
|
func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
|
||||||
// CHECK: "xla_hlo.conv"(%arg0, %arg1)
|
// CHECK: "xla_hlo.conv"(%arg0, %arg1)
|
||||||
|
@ -777,7 +777,7 @@ NamedAttribute GetConvDimensionNumbersAttr(
|
|||||||
// the paddings attribute anyway requires multiple source op attributes and
|
// the paddings attribute anyway requires multiple source op attributes and
|
||||||
// result op attributes. Defining it as declarative rewrite rule will introduce
|
// result op attributes. Defining it as declarative rewrite rule will introduce
|
||||||
// some duplication in the C++ helper methods.
|
// some duplication in the C++ helper methods.
|
||||||
template <typename OpT, int num_spatial_dims>
|
template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
|
||||||
class ConvertConv : public OpRewritePattern<OpT> {
|
class ConvertConv : public OpRewritePattern<OpT> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<OpT>::OpRewritePattern;
|
using OpRewritePattern<OpT>::OpRewritePattern;
|
||||||
@ -868,7 +868,10 @@ class ConvertConv : public OpRewritePattern<OpT> {
|
|||||||
int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
|
int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
|
||||||
// TensorFlow convolution op verifies that the number of input channels is
|
// TensorFlow convolution op verifies that the number of input channels is
|
||||||
// divisible by the number of filter channels.
|
// divisible by the number of filter channels.
|
||||||
int64_t feature_group_count = input_channels / filter_channels;
|
// For depthwise convolution, feature_group_count is instead set
|
||||||
|
// to the input feature dimension.
|
||||||
|
int64_t feature_group_count =
|
||||||
|
depthwise_conv ? input_channels : input_channels / filter_channels;
|
||||||
auto feature_group_count_attr = rewriter.getNamedAttr(
|
auto feature_group_count_attr = rewriter.getNamedAttr(
|
||||||
"feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
|
"feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
|
||||||
|
|
||||||
@ -881,6 +884,22 @@ class ConvertConv : public OpRewritePattern<OpT> {
|
|||||||
"padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
|
"padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
|
||||||
|
|
||||||
SmallVector<Value, 2> operands(op.getOperands());
|
SmallVector<Value, 2> operands(op.getOperands());
|
||||||
|
// Reshape the filter to {spatial_dims..., 1, in_channels *
|
||||||
|
// channel_multiplier}
|
||||||
|
if (depthwise_conv) {
|
||||||
|
auto filter_shape = filter_ty.getShape();
|
||||||
|
llvm::SmallVector<int64_t, 4> new_shape(filter_shape.size());
|
||||||
|
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||||
|
new_shape[i] = filter_shape[i];
|
||||||
|
}
|
||||||
|
new_shape[num_spatial_dims] = 1;
|
||||||
|
new_shape[num_spatial_dims + 1] =
|
||||||
|
filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1];
|
||||||
|
operands[1] = rewriter.create<xla_hlo::ReshapeOp>(
|
||||||
|
op.getLoc(),
|
||||||
|
RankedTensorType::get(new_shape, filter_ty.getElementType()),
|
||||||
|
operands[1]);
|
||||||
|
}
|
||||||
NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
|
NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
|
||||||
dimension_numbers_attr, feature_group_count_attr,
|
dimension_numbers_attr, feature_group_count_attr,
|
||||||
batch_group_count_attr, paddings_attr};
|
batch_group_count_attr, paddings_attr};
|
||||||
@ -891,7 +910,9 @@ class ConvertConv : public OpRewritePattern<OpT> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
using ConvertConv2D = ConvertConv<TF::Conv2DOp, /*num_spatial_dims=*/2>;
|
using ConvertConv2D = ConvertConv<TF::Conv2DOp, /*num_spatial_dims=*/2>;
|
||||||
|
using ConvertDepthConv2D =
|
||||||
|
ConvertConv<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
|
||||||
|
/*depthwise_conv=*/true>;
|
||||||
// Converts BF16 FloorDiv op to have casting operators on either end as BF16
|
// Converts BF16 FloorDiv op to have casting operators on either end as BF16
|
||||||
// division can result in strange behavior.
|
// division can result in strange behavior.
|
||||||
//
|
//
|
||||||
@ -3767,15 +3788,15 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
|
|||||||
patterns.insert<
|
patterns.insert<
|
||||||
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
|
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
|
||||||
ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2D,
|
ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2D,
|
||||||
ConvertConv2DBackpropFilterOp, ConvertConv2DBackpropInputOp,
|
ConvertDepthConv2D, ConvertConv2DBackpropFilterOp,
|
||||||
ConvertCumsumOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp,
|
ConvertConv2DBackpropInputOp, ConvertCumsumOp, ConvertEinsumOp,
|
||||||
ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
|
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
|
||||||
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
|
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
|
||||||
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp,
|
ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, ConvertMaxOp,
|
||||||
ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
|
ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp, ConvertMaxPoolGradOp,
|
||||||
ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
|
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
|
||||||
ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
|
ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp,
|
||||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
||||||
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
||||||
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
|
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user