Support SAME padding, non-default data formats and AvgPool3D in AvgPool lowering
Also added unit tests and enabled all related functional tests for pooling ops. PiperOrigin-RevId: 321433980 Change-Id: Ia175333d81398deadfcc9e18bb7eeabe8e0c67de
This commit is contained in:
parent
806a053eb5
commit
822b035fd1
@ -745,6 +745,30 @@ def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TableGen definition of the `tf.AvgPool3D` op, declared with the
// `NoSideEffect` trait. `value` and `output` are floating-point tensors.
// `ksize` and `strides` are 64-bit integer array attributes constrained by
// `ArrayMinCount<5>`; `padding` must be "SAME" or "VALID"; `data_format` must
// be "NDHWC" or "NCDHW" and defaults to "NDHWC". The derived attribute `T` is
// taken from operand index 0 (`value`).
def TF_AvgPool3DOp : TF_Op<"AvgPool3D", [NoSideEffect]> {
|
||||||
|
let summary = "Performs 3D average pooling on the input.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Each entry in `output` is the mean of the corresponding size `ksize`
|
||||||
|
window in `value`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_FpTensor:$value,
|
||||||
|
|
||||||
|
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
|
||||||
|
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
|
||||||
|
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
|
||||||
|
DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_FpTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
|
def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
|
||||||
let summary = "Computes gradients of the average pooling function.";
|
let summary = "Computes gradients of the average pooling function.";
|
||||||
|
|
||||||
|
@ -980,7 +980,7 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
|
|||||||
|
|
||||||
// SparseMatMul where one operand needs to be transposed and the other one not.
|
// SparseMatMul where one operand needs to be transposed and the other one not.
|
||||||
//
|
//
|
||||||
// CHECK-LABEL: func @test_sparse_mat_mul_with_transpose
|
// CHECK-LABEL: @test_sparse_mat_mul_with_transpose
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
|
// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
|
||||||
// CHECK-SAME: -> tensor<3x5xf32>
|
// CHECK-SAME: -> tensor<3x5xf32>
|
||||||
@ -990,7 +990,6 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
|
|||||||
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
|
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
|
||||||
// CHECK-SAME: -> tensor<3x5xf32>
|
// CHECK-SAME: -> tensor<3x5xf32>
|
||||||
// CHECK: return %[[RESULT]]
|
// CHECK: return %[[RESULT]]
|
||||||
// CHECK: }
|
|
||||||
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
|
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
|
||||||
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
|
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
|
||||||
return %0: tensor<3x5xf32>
|
return %0: tensor<3x5xf32>
|
||||||
@ -998,7 +997,7 @@ func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5
|
|||||||
|
|
||||||
// SparseMatMul where one operand needs to be cast and the other one not.
|
// SparseMatMul where one operand needs to be cast and the other one not.
|
||||||
//
|
//
|
||||||
// CHECK-LABEL: func @test_sparse_mat_mul_with_cast
|
// CHECK-LABEL: @test_sparse_mat_mul_with_cast
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
|
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
|
||||||
// CHECK-SAME: -> tensor<3x5xf32>
|
// CHECK-SAME: -> tensor<3x5xf32>
|
||||||
@ -1007,7 +1006,6 @@ func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5
|
|||||||
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
|
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
|
||||||
// CHECK-SAME: -> tensor<3x5xf32>
|
// CHECK-SAME: -> tensor<3x5xf32>
|
||||||
// CHECK: return %[[RESULT]]
|
// CHECK: return %[[RESULT]]
|
||||||
// CHECK: }
|
|
||||||
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
|
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
|
||||||
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
|
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
|
||||||
return %0: tensor<3x5xf32>
|
return %0: tensor<3x5xf32>
|
||||||
@ -3868,36 +3866,167 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
|
|||||||
// tf.AvgPool legalization
|
// tf.AvgPool legalization
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// CHECK-LABEL: avgpool_valid_padding
|
// CHECK-LABEL: @avgpool_valid_padding
|
||||||
// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x20x7xf16>
|
// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16>
|
||||||
func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> {
|
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32>
|
||||||
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x20x7xf16>) -> tensor<2x12x20x7xf32>
|
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: [[INIT:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
|
||||||
// CHECK: [[REDUCE:%.+]] = "mhlo.reduce_window"([[CONV32]], [[INIT]]) ( {
|
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
|
||||||
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
|
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
|
||||||
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
|
// CHECK: "mhlo.return"([[ADD]])
|
||||||
// CHECK: "mhlo.return"([[ADD]])
|
// CHECK: })
|
||||||
// CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
|
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
|
||||||
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]>
|
||||||
// CHECK: [[DIV:%.+]] = chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
|
// CHECK-SAME: -> tensor<2x3x5x7xf32>
|
||||||
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16>
|
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
||||||
// CHECK: return [[CONV16]]
|
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
|
||||||
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16>
|
// CHECK-SAME: broadcast_dimensions = dense<>
|
||||||
|
// CHECK-SAME: -> tensor<2x3x5x7xf32>
|
||||||
|
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
|
||||||
|
// CHECK-SAME: -> tensor<2x3x5x7xf16>
|
||||||
|
// CHECK: return [[CONV16]]
|
||||||
|
// Test input: 2-D tf.AvgPool in NHWC layout with VALID padding, a 2x2 window
// (ksize [1, 2, 2, 1]) and strides [1, 4, 4, 1] on an f16 tensor.
func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> {
|
||||||
|
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16>
|
||||||
return %0 : tensor<2x3x5x7xf16>
|
return %0 : tensor<2x3x5x7xf16>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: avgpool_same_padding
|
// CHECK-LABEL: @avgpool_3d_valid_padding
|
||||||
func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32> {
|
// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16>
|
||||||
// CHECK: tf.AvgPool
|
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32>
|
||||||
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
|
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
return %0 : tensor<2x4x7x7xf32>
|
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
|
||||||
|
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
|
||||||
|
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
|
||||||
|
// CHECK: "mhlo.return"([[ADD]])
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]>
|
||||||
|
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]>
|
||||||
|
// CHECK-SAME: -> tensor<2x4x3x5x7xf32>
|
||||||
|
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
|
||||||
|
// CHECK-SAME: broadcast_dimensions = dense<>
|
||||||
|
// CHECK-SAME: -> tensor<2x4x3x5x7xf32>
|
||||||
|
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
|
||||||
|
// CHECK-SAME: -> tensor<2x4x3x5x7xf16>
|
||||||
|
// CHECK: return [[CONV16]]
|
||||||
|
// Test input: tf.AvgPool3D in NDHWC layout with VALID padding,
// ksize [1, 1, 2, 2, 1] and strides [1, 1, 4, 4, 1] on an f16 tensor.
func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> {
|
||||||
|
%0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16>
|
||||||
|
return %0 : tensor<2x4x3x5x7xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @avgpool_nchw_format
|
||||||
|
// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16>
|
||||||
|
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32>
|
||||||
|
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
|
||||||
|
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
|
||||||
|
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
|
||||||
|
// CHECK: "mhlo.return"([[ADD]])
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]>
|
||||||
|
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]>
|
||||||
|
// CHECK-SAME: -> tensor<2x7x3x5xf32>
|
||||||
|
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
|
||||||
|
// CHECK-SAME: broadcast_dimensions = dense<>
|
||||||
|
// CHECK-SAME: -> tensor<2x7x3x5xf32>
|
||||||
|
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
|
||||||
|
// CHECK-SAME: -> tensor<2x7x3x5xf16>
|
||||||
|
// CHECK: return [[CONV16]]
|
||||||
|
// Test input: 2-D tf.AvgPool in the non-default NCHW layout with VALID
// padding, ksize [1, 1, 2, 2] and strides [1, 1, 4, 4] on an f16 tensor.
func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> {
|
||||||
|
%0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16>
|
||||||
|
return %0 : tensor<2x7x3x5xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @avgpool_3d_ncdhw_format
|
||||||
|
// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16>
|
||||||
|
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32>
|
||||||
|
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
|
||||||
|
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
|
||||||
|
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
|
||||||
|
// CHECK: "mhlo.return"([[ADD]])
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]>
|
||||||
|
// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]>
|
||||||
|
// CHECK-SAME: -> tensor<2x7x4x3x5xf32>
|
||||||
|
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
|
||||||
|
// CHECK-SAME: broadcast_dimensions = dense<>
|
||||||
|
// CHECK-SAME: -> tensor<2x7x4x3x5xf32>
|
||||||
|
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
|
||||||
|
// CHECK-SAME: -> tensor<2x7x4x3x5xf16>
|
||||||
|
// CHECK: return [[CONV16]]
|
||||||
|
// Test input: tf.AvgPool3D in the non-default NCDHW layout with VALID
// padding, ksize [1, 1, 1, 2, 2] and strides [1, 1, 1, 4, 4] on an f16 tensor.
func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> {
|
||||||
|
%0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16>
|
||||||
|
return %0 : tensor<2x7x4x3x5xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @avgpool_same_padding(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
|
||||||
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
|
||||||
|
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
|
||||||
|
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
||||||
|
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
|
||||||
|
// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]>
|
||||||
|
// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]>
|
||||||
|
// CHECK-SAME: -> tensor<2x4x6x7xf32>
|
||||||
|
// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32>
|
||||||
|
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
|
||||||
|
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
|
||||||
|
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
|
||||||
|
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
|
||||||
|
// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]>
|
||||||
|
// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]>
|
||||||
|
// CHECK-SAME: -> tensor<2x4x6x7xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32>
|
||||||
|
// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32>
|
||||||
|
// CHECK: }
|
||||||
|
// Test input: 2-D tf.AvgPool in NHWC layout with SAME padding,
// ksize [1, 5, 2, 1] and strides [1, 3, 4, 1] on an f32 tensor.
func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> {
|
||||||
|
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
|
||||||
|
return %0 : tensor<2x4x6x7xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @avgpool_3d_same_padding(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
|
||||||
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
|
||||||
|
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
|
||||||
|
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
||||||
|
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
|
||||||
|
// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]>
|
||||||
|
// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]>
|
||||||
|
// CHECK-SAME: -> tensor<2x4x4x6x7xf32>
|
||||||
|
// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32>
|
||||||
|
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
|
||||||
|
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
|
||||||
|
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
|
||||||
|
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
|
||||||
|
// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]>
|
||||||
|
// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]>
|
||||||
|
// CHECK-SAME: -> tensor<2x4x4x6x7xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]]
|
||||||
|
// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32>
|
||||||
|
// CHECK: }
|
||||||
|
// Test input: tf.AvgPool3D in NDHWC layout with SAME padding,
// ksize [1, 1, 5, 2, 1] and strides [1, 1, 3, 4, 1] on an f32 tensor.
func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> {
|
||||||
|
%0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
|
||||||
|
return %0 : tensor<2x4x4x6x7xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AvgPoolGrad op legalizations.
|
// AvgPoolGrad op legalizations.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// CHECK-LABEL: func @avgpool_grad_valid_padding(
|
// CHECK-LABEL: @avgpool_grad_valid_padding(
|
||||||
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
|
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
|
||||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
||||||
@ -3929,7 +4058,7 @@ func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24
|
|||||||
return %result : tensor<10x24x32x64xf32>
|
return %result : tensor<10x24x32x64xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avgpool_3d_grad_valid_padding(
|
// CHECK-LABEL: @avgpool_3d_grad_valid_padding(
|
||||||
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
|
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
|
||||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
|
||||||
@ -3958,7 +4087,7 @@ func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<
|
|||||||
return %result : tensor<10x8x24x32x64xf32>
|
return %result : tensor<10x8x24x32x64xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avgpool_grad_same_padding(
|
// CHECK-LABEL: @avgpool_grad_same_padding(
|
||||||
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
|
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
|
||||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
|
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
|
||||||
@ -3997,7 +4126,7 @@ func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9x
|
|||||||
return %result : tensor<2x13x25x9xf32>
|
return %result : tensor<2x13x25x9xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avgpool_3d_grad_same_padding(
|
// CHECK-LABEL: @avgpool_3d_grad_same_padding(
|
||||||
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
|
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
|
||||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
|
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
|
||||||
@ -4035,7 +4164,7 @@ func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x1
|
|||||||
return %result : tensor<2x8x13x25x9xf32>
|
return %result : tensor<2x8x13x25x9xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avgpool_grad_nchw_format(
|
// CHECK-LABEL: @avgpool_grad_nchw_format(
|
||||||
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
|
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
|
||||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
|
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
|
||||||
@ -4074,7 +4203,7 @@ func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf
|
|||||||
return %result : tensor<2x9x13x25xf32>
|
return %result : tensor<2x9x13x25xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avgpool_3d_grad_ncdwh_format(
|
// CHECK-LABEL: @avgpool_3d_grad_ncdwh_format(
|
||||||
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
|
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
|
||||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
|
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
|
||||||
@ -4112,7 +4241,7 @@ func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8
|
|||||||
return %result : tensor<2x9x8x13x25xf32>
|
return %result : tensor<2x9x8x13x25xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avgpool_grad_bf16(
|
// CHECK-LABEL: @avgpool_grad_bf16(
|
||||||
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
|
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
|
||||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
|
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
|
||||||
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
|
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
|
||||||
|
@ -1770,30 +1770,93 @@ static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
|
|||||||
flatten_paddings);
|
flatten_paddings);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function for dividing each entry of `pooled` by the count of its
|
||||||
|
// corresponding window, i.e., the number of non-padding entries of the window
|
||||||
|
// which an `AvgPool` operation performed on an `input_shape`-tensor would map
|
||||||
|
// to this entry, depending on `ksize` and `strides`. This function is used for
|
||||||
|
// `AvgPool` and `AvgPoolGrad` legalizations.
|
||||||
|
// `zero` is passed as a parameter because it can be reused from caller level.
|
||||||
|
// `pooled` must have `RankedTensorType`.
|
||||||
|
template <typename OpTy, int num_dims>
|
||||||
|
Operation *AvgPoolDivideByCount(
|
||||||
|
Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
|
||||||
|
const SmallVector<int64_t, num_dims> &ksize,
|
||||||
|
const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
RankedTensorType pooled_type =
|
||||||
|
pooled.getType().template cast<RankedTensorType>();
|
||||||
|
Type element_type = pooled_type.getElementType();
|
||||||
|
Operation *result = nullptr;
|
||||||
|
RankedTensorType orig_input_type =
|
||||||
|
RankedTensorType::get(input_shape, element_type);
|
||||||
|
|
||||||
|
if (op.padding() == "VALID") {
|
||||||
|
// All window counts are equal here because we don't have padding
|
||||||
|
// (each entry of `pooled` corresponds to a window that consists of
|
||||||
|
// original input entries only).
|
||||||
|
int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
|
||||||
|
std::multiplies<int64_t>());
|
||||||
|
// Divide `pooled` by window counts.
|
||||||
|
Value divisor =
|
||||||
|
GetScalarConstOfType(element_type, loc, window_count, &rewriter);
|
||||||
|
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
|
||||||
|
result = rewriter.create<chlo::BroadcastDivOp>(
|
||||||
|
loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
|
||||||
|
} else {
|
||||||
|
assert(op.padding() == "SAME");
|
||||||
|
// For SAME padding, only original entries that contributed to a window
|
||||||
|
// are counted for the average of this window, not padded entries.
|
||||||
|
|
||||||
|
// Build all-ones tensor of same shape as the original input.
|
||||||
|
ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
|
||||||
|
auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
|
||||||
|
|
||||||
|
// Get padding for the input.
|
||||||
|
DenseIntElementsAttr input_padding_attr =
|
||||||
|
GetReduceWindowPaddingAsAttr<num_dims>(
|
||||||
|
input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
|
||||||
|
|
||||||
|
// Count the 1's in each window, using the same padding as for the input,
|
||||||
|
// which gives us the window counts by which `pooled` needs to be divided.
|
||||||
|
auto divisor = rewriter.create<ReduceWindowOp>(
|
||||||
|
loc, pooled_type,
|
||||||
|
/*operand=*/all_ones_tensor,
|
||||||
|
/*init_value=*/zero,
|
||||||
|
/*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
|
||||||
|
/*window_strides=*/GetI64ElementsAttr(op.strides()),
|
||||||
|
/*base_dilations=*/DenseIntElementsAttr(),
|
||||||
|
/*window_dilations=*/DenseIntElementsAttr(),
|
||||||
|
/*padding=*/input_padding_attr);
|
||||||
|
BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
|
||||||
|
|
||||||
|
// Divide `pooled` by window counts.
|
||||||
|
result = rewriter.create<mhlo::DivOp>(loc, pooled_type, pooled, divisor);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
|
// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
|
||||||
// dimensions with add as the reduction function. The reduction result is
|
// dimensions with add as the reduction function. The reduction result is
|
||||||
// then divided by the number of elements in the window.
|
// then divided by the number of elements in the window.
|
||||||
class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
|
template <typename OpTy, int num_dims>
|
||||||
|
class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(TF::AvgPoolOp op,
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
|
auto input_type =
|
||||||
|
op.value().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (!input_type) return failure();
|
if (!input_type) return failure();
|
||||||
|
|
||||||
// TODO(b/147217034): support other data formats.
|
|
||||||
if (!IsDefaultDataFormat(op.data_format())) return failure();
|
|
||||||
// TODO(b/147217034): support "SAME" padding.
|
|
||||||
if (op.padding() != "VALID") return failure();
|
|
||||||
|
|
||||||
// We will do accumulation first; use a larger bitwidth if suitable.
|
// We will do accumulation first; use a larger bitwidth if suitable.
|
||||||
Type input_element_type = input_type.getElementType();
|
Type input_element_type = input_type.getElementType();
|
||||||
Type sum_element_type = GetSumAccumulationType(input_element_type);
|
Type sum_element_type = GetSumAccumulationType(input_element_type);
|
||||||
Type result_type;
|
Type result_type;
|
||||||
|
|
||||||
// The result type for reduction and division with the proper element type.
|
// The result type for reduction and division with the proper element type.
|
||||||
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>())
|
if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
|
||||||
result_type =
|
result_type =
|
||||||
RankedTensorType::get(ranked_type.getShape(), sum_element_type);
|
RankedTensorType::get(ranked_type.getShape(), sum_element_type);
|
||||||
else
|
else
|
||||||
@ -1809,9 +1872,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
|
|||||||
// Create the mhlo.reduce_window op.
|
// Create the mhlo.reduce_window op.
|
||||||
Value init =
|
Value init =
|
||||||
GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
|
GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
|
||||||
DenseIntElementsAttr paddings_attr =
|
DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
|
||||||
GetReduceWindowPaddingAsAttr<4>(input_type.getShape(), op.ksize(),
|
input_type.getShape(), op.ksize(), op.strides(), op.padding(),
|
||||||
op.strides(), op.padding(), &rewriter);
|
&rewriter);
|
||||||
auto reduce = rewriter.create<ReduceWindowOp>(
|
auto reduce = rewriter.create<ReduceWindowOp>(
|
||||||
op.getLoc(), result_type, input_value, init,
|
op.getLoc(), result_type, input_value, init,
|
||||||
GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
|
GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
|
||||||
@ -1821,19 +1884,17 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
|
|||||||
|
|
||||||
// Count the number of elements in the window. The following calculation
|
// Gather the input shape, window sizes and strides for AvgPoolDivideByCount,
|
||||||
// is only valid for no paddings.
|
// which divides each window sum by its count of non-padding entries.
|
||||||
SmallVector<int64_t, 4> ksize;
|
SmallVector<int64_t, num_dims> input_shape(
|
||||||
|
llvm::to_vector<num_dims>(input_type.getShape()));
|
||||||
|
SmallVector<int64_t, num_dims> ksize, strides;
|
||||||
GetI64ArrayAttrValues(op.ksize(), &ksize);
|
GetI64ArrayAttrValues(op.ksize(), &ksize);
|
||||||
int64_t count = std::accumulate(ksize.begin(), ksize.end(), 1,
|
GetI64ArrayAttrValues(op.strides(), &strides);
|
||||||
std::multiplies<int64_t>());
|
|
||||||
|
|
||||||
// Divide by the number of elements in the window.
|
Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
|
||||||
Value divisor =
|
reduce.getResult(), input_shape, ksize, strides, op, init, rewriter);
|
||||||
GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
|
|
||||||
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
|
|
||||||
Value result = rewriter.create<chlo::BroadcastDivOp>(
|
|
||||||
op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims);
|
|
||||||
|
|
||||||
// Convert back if we enlarged the element type's bitwidth.
|
// Convert back if we enlarged the element type's bitwidth.
|
||||||
|
Value result = result_op->getOpResult(0);
|
||||||
if (input_element_type != sum_element_type)
|
if (input_element_type != sum_element_type)
|
||||||
result =
|
result =
|
||||||
rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
|
rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
|
||||||
@ -1843,6 +1904,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
|
||||||
|
using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
|
||||||
|
|
||||||
// `AvgPoolGradOp` is converted to the following operations:
|
// `AvgPoolGradOp` is converted to the following operations:
|
||||||
// 1. Divide each entry of the output gradient (the gradient for the previous
|
// 1. Divide each entry of the output gradient (the gradient for the previous
|
||||||
// layer in backpropagation order) by the count of the corresponding window
|
// layer in backpropagation order) by the count of the corresponding window
|
||||||
@ -1916,59 +1980,13 @@ class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
|
|||||||
auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
|
auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
|
||||||
DimVector orig_input_shape(orig_input_shape_values.begin(),
|
DimVector orig_input_shape(orig_input_shape_values.begin(),
|
||||||
orig_input_shape_values.end());
|
orig_input_shape_values.end());
|
||||||
RankedTensorType orig_input_type =
|
|
||||||
RankedTensorType::get(orig_input_shape, element_type);
|
|
||||||
DimVector ksize, strides;
|
DimVector ksize, strides;
|
||||||
GetI64ArrayAttrValues(op.ksize(), &ksize);
|
GetI64ArrayAttrValues(op.ksize(), &ksize);
|
||||||
GetI64ArrayAttrValues(op.strides(), &strides);
|
GetI64ArrayAttrValues(op.strides(), &strides);
|
||||||
Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
|
Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
|
||||||
|
|
||||||
Operation *out_grad_divided = nullptr;
|
auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
|
||||||
if (op.padding() == "VALID") {
|
out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
|
||||||
// All window counts are equal here because we don't have padding
|
|
||||||
// (each entry of `out_grad` corresponds to a window that consists of
|
|
||||||
// original input entries only).
|
|
||||||
int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
|
|
||||||
std::multiplies<int64_t>());
|
|
||||||
// Divide `out_grad` by window counts.
|
|
||||||
Value divisor =
|
|
||||||
GetScalarConstOfType(element_type, loc, window_count, &rewriter);
|
|
||||||
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
|
|
||||||
out_grad_divided = rewriter.create<chlo::BroadcastDivOp>(
|
|
||||||
loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims);
|
|
||||||
} else {
|
|
||||||
assert(op.padding() == "SAME");
|
|
||||||
// For SAME padding, only original entries that contributed to a window
|
|
||||||
// are counted for the average of this window, not padded entries.
|
|
||||||
|
|
||||||
// Build all-ones tensor of same shape as the original input.
|
|
||||||
ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
|
|
||||||
auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
|
|
||||||
|
|
||||||
// Get the same padding as for the original input.
|
|
||||||
DenseIntElementsAttr orig_padding_attr =
|
|
||||||
GetReduceWindowPaddingAsAttr<num_dims>(orig_input_shape, op.ksize(),
|
|
||||||
op.strides(), op.padding(),
|
|
||||||
&rewriter);
|
|
||||||
|
|
||||||
// Count the 1's in each window, using the same padding as for the
|
|
||||||
// original input, which gives us the window counts by which `out_grad`
|
|
||||||
// needs to be divided.
|
|
||||||
auto window_counts = rewriter.create<ReduceWindowOp>(
|
|
||||||
loc, out_grad_type,
|
|
||||||
/*operand=*/all_ones_tensor,
|
|
||||||
/*init_value=*/zero,
|
|
||||||
/*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
|
|
||||||
/*window_strides=*/GetI64ElementsAttr(op.strides()),
|
|
||||||
/*base_dilations=*/DenseIntElementsAttr(),
|
|
||||||
/*window_dilations=*/DenseIntElementsAttr(),
|
|
||||||
/*padding=*/orig_padding_attr);
|
|
||||||
BuildReduceBody<AddOp>(element_type, &window_counts.body(), &rewriter);
|
|
||||||
|
|
||||||
// Divide `out_grad` by window counts.
|
|
||||||
out_grad_divided = rewriter.create<mhlo::DivOp>(loc, out_grad_type,
|
|
||||||
out_grad, window_counts);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get same padding as for original input.
|
// Get same padding as for original input.
|
||||||
PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
|
PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
|
||||||
@ -5506,10 +5524,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
|||||||
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
|
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
|
||||||
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
|
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
|
||||||
ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
|
ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
|
||||||
ConvertAvgPoolOp, ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp,
|
ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp,
|
||||||
ConvertMaxPool2DOp, ConvertMaxPool3DOp, ConvertMaxPool2DGradOp,
|
ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
|
||||||
ConvertMaxPool3DGradOp, ConvertMeanOp, ConvertOneHotOp,
|
ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
|
||||||
ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
|
ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
|
||||||
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
|
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
|
||||||
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
||||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||||
|
@ -936,7 +936,7 @@ tf_xla_py_test(
|
|||||||
srcs = ["pooling_ops_test.py"],
|
srcs = ["pooling_ops_test.py"],
|
||||||
enable_mlir_bridge = True,
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 10,
|
shard_count = 20,
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
],
|
],
|
||||||
@ -956,7 +956,7 @@ tf_xla_py_test(
|
|||||||
srcs = ["pooling_ops_3d_test.py"],
|
srcs = ["pooling_ops_3d_test.py"],
|
||||||
enable_mlir_bridge = True,
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 10,
|
shard_count = 20,
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
],
|
],
|
||||||
|
@ -75,9 +75,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
|||||||
actual = vals.flatten()
|
actual = vals.flatten()
|
||||||
self.assertAllClose(expected, actual)
|
self.assertAllClose(expected, actual)
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
|
|
||||||
" doesn't support all paddings and data "
|
|
||||||
"formats")
|
|
||||||
def testAvgPool3dValidPadding(self):
|
def testAvgPool3dValidPadding(self):
|
||||||
expected_output = [20.5, 21.5, 22.5]
|
expected_output = [20.5, 21.5, 22.5]
|
||||||
self._VerifyValues(
|
self._VerifyValues(
|
||||||
@ -88,9 +85,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
|||||||
padding="VALID",
|
padding="VALID",
|
||||||
expected=expected_output)
|
expected=expected_output)
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
|
|
||||||
" doesn't support all paddings and data "
|
|
||||||
"formats")
|
|
||||||
def testAvgPool3dSamePadding(self):
|
def testAvgPool3dSamePadding(self):
|
||||||
expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5]
|
expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5]
|
||||||
self._VerifyValues(
|
self._VerifyValues(
|
||||||
@ -101,9 +95,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
|||||||
padding="SAME",
|
padding="SAME",
|
||||||
expected=expected_output)
|
expected=expected_output)
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
|
|
||||||
" doesn't support all paddings and data "
|
|
||||||
"formats")
|
|
||||||
def testAvgPool3dSamePaddingDifferentStrides(self):
|
def testAvgPool3dSamePaddingDifferentStrides(self):
|
||||||
expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5]
|
expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5]
|
||||||
self._VerifyValues(
|
self._VerifyValues(
|
||||||
|
@ -268,9 +268,6 @@ class PoolingTest(xla_test.XLATestCase):
|
|||||||
expected=[1, 3, 9, 11])
|
expected=[1, 3, 9, 11])
|
||||||
|
|
||||||
# Average pooling
|
# Average pooling
|
||||||
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
|
|
||||||
" doesn't support all paddings and data "
|
|
||||||
"formats")
|
|
||||||
def testAvgPoolValidPadding(self):
|
def testAvgPoolValidPadding(self):
|
||||||
expected_output = [7, 8, 9]
|
expected_output = [7, 8, 9]
|
||||||
self._VerifyValues(
|
self._VerifyValues(
|
||||||
@ -281,9 +278,6 @@ class PoolingTest(xla_test.XLATestCase):
|
|||||||
padding="VALID",
|
padding="VALID",
|
||||||
expected=expected_output)
|
expected=expected_output)
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
|
|
||||||
" doesn't support all paddings and data "
|
|
||||||
"formats")
|
|
||||||
def testAvgPoolSamePadding(self):
|
def testAvgPoolSamePadding(self):
|
||||||
expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
|
expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
|
||||||
self._VerifyValues(
|
self._VerifyValues(
|
||||||
|
Loading…
Reference in New Issue
Block a user