Add tf.MaxPool3DGrad -> xla_hlo.select_and_scatter/xla_hlo.compare/xla_hlo.add legalization.

This reuses the current legalization for MaxPool{2D}Grad.

PiperOrigin-RevId: 307636800
Change-Id: If3dcab69c9a0ebc7eeaf1140387f5030934cfa54
This commit is contained in:
Andy Ly 2020-04-21 10:49:50 -07:00 committed by TensorFlower Gardener
parent 7f8add7ab8
commit 121ee5d9b3
2 changed files with 44 additions and 11 deletions

View File

@ -1398,6 +1398,25 @@ func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: te
return %result : tensor<10x24x24x64xf32>
}
// Verifies that tf.MaxPool3DGrad with VALID padding legalizes to a single
// xla_hlo.select_and_scatter op: the select region picks the max element via
// an xla_hlo.compare with comparison_direction = "GE", and the scatter region
// accumulates routed gradients with xla_hlo.add. The 5-D NDHWC ksize/strides
// [1, 1, 2, 2, 1] are carried over verbatim as tensor<5xi64> window attributes.
// CHECK-LABEL: @max_pool_3d_grad_valid
// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
  // The init value fed to select_and_scatter is a scalar zero of the element
  // type; VALID padding means no padding attribute is expected on the op.
  // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[RESULT:.*]] = "xla_hlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = "xla_hlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
  // CHECK: }, {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = xla_hlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
  // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
  // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
  // CHECK: }
  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
  return %result : tensor<10x8x24x24x64xf32>
}
// CHECK-LABEL: @max_pool_grad_same
func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
// CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
@ -1410,6 +1429,13 @@ func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tenso
return %result : tensor<2x13x25x7xf32>
}
// Verifies that tf.MaxPool3DGrad with SAME padding produces a 5x2 padding
// attribute (one [lo, hi] pair per NDHWC dimension). With ksize [1, 1, 2, 3, 1]
// and strides [1, 1, 4, 4, 1] over a 13x25 spatial extent, the two pooled
// spatial dims get [0, 1] and [1, 1] padding while batch/depth/channel get
// [0, 0] — this checks only the padding attribute of the lowered op.
// CHECK-LABEL: @max_pool_3d_grad_same
func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32>
  return %result : tensor<2x8x13x25x7xf32>
}
//===----------------------------------------------------------------------===//
// OneHot op legalizations.
//===----------------------------------------------------------------------===//

View File

@ -1352,7 +1352,7 @@ static DenseIntElementsAttr GetReduceWindowPadding(
flatten_paddings);
}
// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
// dimensions with add as the reduction function. The reduction result is
// then divided by the number of elements in the window.
class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
@ -2705,23 +2705,25 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
}
};
class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
template <typename OpTy, int num_dims>
class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern::OpRewritePattern;
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::MaxPoolGradOp op,
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type element_type =
op.orig_input().getType().cast<TensorType>().getElementType();
op.orig_input().getType().template cast<TensorType>().getElementType();
// Compute paddings using the original input and kernel shape and strides.
// Here, ReduceWindow op is used as the MaxPool op is lowered to the
// ReduceWindow op.
auto input_ty = op.orig_input().getType().dyn_cast<RankedTensorType>();
auto input_ty =
op.orig_input().getType().template dyn_cast<RankedTensorType>();
if (!input_ty) return failure();
DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<4>(
DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<num_dims>(
input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
auto result = rewriter.create<SelectAndScatterOp>(
@ -2752,6 +2754,11 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
}
};
using ConvertMaxPool2DGradOp =
ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
using ConvertMaxPool3DGradOp =
ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
// Converts tf.Conv?DBackpropInputOp into:
// %rev_filter = "xla_hlo.reverse"(%filter)
// %result = "xla_hlo.convolution"(%out_backprop, %rev_filter)
@ -3925,10 +3932,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
ConvertMaxPool3DOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp,
ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,