diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 2ce675a4a30..aae23ee5c56 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1398,6 +1398,25 @@ func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: te
   return %result : tensor<10x24x24x64xf32>
 }
 
+// CHECK-LABEL: @max_pool_3d_grad_valid
+// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
+func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
+  // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
+  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
+  // CHECK: %[[SELECT_RESULT:.*]] = "xla_hlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
+  // CHECK: }, {
+  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
+  // CHECK: %[[SELECT_RESULT:.*]] = xla_hlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
+  // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
+  // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
+  // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
+  // CHECK: }
+  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
+  return %result : tensor<10x8x24x24x64xf32>
+}
+
 // CHECK-LABEL: @max_pool_grad_same
 func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
   // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
@@ -1410,6 +1429,13 @@ func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tenso
   return %result : tensor<2x13x25x7xf32>
 }
 
+// CHECK-LABEL: @max_pool_3d_grad_same
+func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
+  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
+  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32>
+  return %result : tensor<2x8x13x25x7xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // OneHot op legalizations.
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 6feb815917b..44aacd85518 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -1352,7 +1352,7 @@ static DenseIntElementsAttr GetReduceWindowPadding(
       flatten_paddings);
 }
 
-// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
+// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
 // dimensions with add as the reduction function. The reduction result is
 // then divided by the number of elements in the window.
 class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
@@ -2705,23 +2705,25 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
   }
 };
 
-class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
+template <typename OpTy, int num_dims>
+class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
  public:
-  using OpRewritePattern<TF::MaxPoolGradOp>::OpRewritePattern;
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TF::MaxPoolGradOp op,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
     Type element_type =
-        op.orig_input().getType().cast<TensorType>().getElementType();
+        op.orig_input().getType().template cast<TensorType>().getElementType();
 
     // Compute paddings using the original input and kernel shape and strides.
     // Here, ReduceWindow op as used as the MaxPool op is lowered to the
     // ReduceWindow op.
-    auto input_ty = op.orig_input().getType().dyn_cast<RankedTensorType>();
+    auto input_ty =
+        op.orig_input().getType().template dyn_cast<RankedTensorType>();
     if (!input_ty) return failure();
-    DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<4>(
+    DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<num_dims>(
         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
 
     auto result = rewriter.create<SelectAndScatterOp>(
@@ -2752,6 +2754,11 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
   }
 };
 
+using ConvertMaxPool2DGradOp =
+    ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
+using ConvertMaxPool3DGradOp =
+    ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
+
 // Converts tf.Conv?DBackpropInputOp into:
 //   %rev_filter = "xla_hlo.reverse"(%filter)
 //   %result = "xla_hlo.convolution"(%out_backprop, %rev_filter)
@@ -3925,10 +3932,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
       ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
       ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
       ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
-      ConvertMaxPool3DOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
-      ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
-      ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
-      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
+      ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
+      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
+      ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp,
+      ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
       ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,