Add tf.MaxPool3DGrad -> xla_hlo.select_and_scatter/xla_hlo.compare/xla_hlo.add legalization.
This reuses the current legalization for MaxPool{2D}Grad.

PiperOrigin-RevId: 307636800
Change-Id: If3dcab69c9a0ebc7eeaf1140387f5030934cfa54
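For orientation, here is a condensed sketch of the lowering this enables, distilled from the CHECK lines in the tests added below. The block-argument names %a and %b are illustrative; the shapes are those of the VALID-padding test case:

    %zero = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
    %result = "xla_hlo.select_and_scatter"(%orig_input, %grad, %zero) ( {
    ^bb0(%a: tensor<f32>, %b: tensor<f32>):
      // Select region: pick the window element that was the max (GE compare).
      %sel = "xla_hlo.compare"(%a, %b) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
      "xla_hlo.return"(%sel) : (tensor<i1>) -> ()
    }, {
    ^bb0(%a: tensor<f32>, %b: tensor<f32>):
      // Scatter region: accumulate routed gradients with add.
      %sum = xla_hlo.add %a, %b : tensor<f32>
      "xla_hlo.return"(%sum) : (tensor<f32>) -> ()
    }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>,
        window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>}
      : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>)
      -> tensor<10x8x24x24x64xf32>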
commit 121ee5d9b3 (parent 7f8add7ab8)
@@ -1398,6 +1398,25 @@ func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: te
   return %result : tensor<10x24x24x64xf32>
 }
 
+// CHECK-LABEL: @max_pool_3d_grad_valid
+// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
+func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
+  // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
+  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
+  // CHECK: %[[SELECT_RESULT:.*]] = "xla_hlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
+  // CHECK: }, {
+  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
+  // CHECK: %[[SELECT_RESULT:.*]] = xla_hlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
+  // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
+  // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
+  // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
+  // CHECK: }
+  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
+  return %result : tensor<10x8x24x24x64xf32>
+}
+
 // CHECK-LABEL: @max_pool_grad_same
 func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
   // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
@@ -1410,6 +1429,13 @@ func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tenso
   return %result : tensor<2x13x25x7xf32>
 }
 
+// CHECK-LABEL: @max_pool_3d_grad_same
+func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
+  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
+  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32>
+  return %result : tensor<2x8x13x25x7xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // OneHot op legalizations.
 //===----------------------------------------------------------------------===//
@@ -1352,7 +1352,7 @@ static DenseIntElementsAttr GetReduceWindowPadding(
       flatten_paddings);
 }
 
-// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
+// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
 // dimensions with add as the reduction function. The reduction result is
 // then divided by the number of elements in the window.
 class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
@@ -2705,23 +2705,25 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
   }
 };
 
-class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
+template <typename OpTy, int num_dims>
+class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
  public:
-  using OpRewritePattern::OpRewritePattern;
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TF::MaxPoolGradOp op,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
     Type element_type =
-        op.orig_input().getType().cast<TensorType>().getElementType();
+        op.orig_input().getType().template cast<TensorType>().getElementType();
 
     // Compute paddings using the original input and kernel shape and strides.
     // Here, ReduceWindow op as used as the MaxPool op is lowered to the
     // ReduceWindow op.
-    auto input_ty = op.orig_input().getType().dyn_cast<RankedTensorType>();
+    auto input_ty =
+        op.orig_input().getType().template dyn_cast<RankedTensorType>();
     if (!input_ty) return failure();
-    DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<4>(
+    DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<num_dims>(
         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
 
     auto result = rewriter.create<SelectAndScatterOp>(
@@ -2752,6 +2754,11 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
   }
 };
 
+using ConvertMaxPool2DGradOp =
+    ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
+using ConvertMaxPool3DGradOp =
+    ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
+
 // Converts tf.Conv?DBackpropInputOp into:
 //   %rev_filter = "xla_hlo.reverse"(%filter)
 //   %result = "xla_hlo.convolution"(%out_backprop, %rev_filter)
@@ -3925,10 +3932,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
       ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
       ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
       ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
-      ConvertMaxPool3DOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
-      ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
-      ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
-      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
+      ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
+      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
+      ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp,
+      ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
       ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,