Support SAME padding, non-default data formats and AvgPool3D in AvgPool lowering

Also added unit tests and enabled all related functional tests for pooling ops.
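For context, a minimal Python sketch of the cases this enables (shapes chosen to mirror the new MLIR tests below; this is an illustration, not code from the change). With SAME padding the average is taken only over the non-padding elements of each window, which is what the lowering now has to reproduce:

import tensorflow as tf

# 2-D average pooling with SAME padding (NHWC): 2x12x21x7 -> 2x4x6x7.
x = tf.random.normal([2, 12, 21, 7])
y = tf.nn.avg_pool2d(x, ksize=[1, 5, 2, 1], strides=[1, 3, 4, 1], padding="SAME")
print(y.shape)  # (2, 4, 6, 7)

# 3-D average pooling with SAME padding (NDHWC): 2x4x12x21x7 -> 2x4x4x6x7.
x3 = tf.random.normal([2, 4, 12, 21, 7])
y3 = tf.nn.avg_pool3d(x3, ksize=[1, 1, 5, 2, 1], strides=[1, 1, 3, 4, 1], padding="SAME")
print(y3.shape)  # (2, 4, 4, 6, 7)

NCHW/NCDHW inputs follow the same pattern: the ksize and strides attributes are already expressed in the op's own layout, so the lowering can forward them to the reduce window unchanged (see the avgpool_nchw_format and avgpool_3d_ncdhw_format tests below).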

PiperOrigin-RevId: 321433980
Change-Id: Ia175333d81398deadfcc9e18bb7eeabe8e0c67de
Michael Gester 2020-07-15 14:02:00 -07:00 committed by TensorFlower Gardener
parent 806a053eb5
commit 822b035fd1
6 changed files with 278 additions and 122 deletions


@@ -745,6 +745,30 @@ def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_AvgPool3DOp : TF_Op<"AvgPool3D", [NoSideEffect]> {
let summary = "Performs 3D average pooling on the input.";
let description = [{
Each entry in `output` is the mean of the corresponding size `ksize`
window in `value`.
}];
let arguments = (ins
TF_FpTensor:$value,
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
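As a rough usage sketch (my own illustration, not part of this change), the arguments declared above map onto the Python raw-op wrapper; the ODS operand `value` is the 5-D input tensor, and `ksize`/`strides` are length-5 arrays:

import tensorflow as tf

# `value` in the ODS corresponds to the `input` tensor of the raw op wrapper.
out = tf.raw_ops.AvgPool3D(
    input=tf.ones([1, 4, 6, 6, 3]),
    ksize=[1, 2, 2, 2, 1],
    strides=[1, 2, 2, 2, 1],
    padding="VALID",
    data_format="NDHWC")
print(out.shape)  # (1, 2, 3, 3, 3)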
def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
let summary = "Computes gradients of the average pooling function.";


@@ -980,7 +980,7 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
// SparseMatMul where one operand needs to be transposed and the other one not.
//
-// CHECK-LABEL: func @test_sparse_mat_mul_with_transpose
// CHECK-LABEL: @test_sparse_mat_mul_with_transpose
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
// CHECK-SAME: -> tensor<3x5xf32>
@@ -990,7 +990,6 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: return %[[RESULT]]
-// CHECK: }
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
return %0: tensor<3x5xf32>
@@ -998,7 +997,7 @@ func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5
// SparseMatMul where one operand needs to be casted and the other one not.
//
-// CHECK-LABEL: func @test_sparse_mat_mul_with_cast
// CHECK-LABEL: @test_sparse_mat_mul_with_cast
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
// CHECK-SAME: -> tensor<3x5xf32>
@@ -1007,7 +1006,6 @@ func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: return %[[RESULT]]
-// CHECK: }
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
return %0: tensor<3x5xf32>
@@ -3868,36 +3866,167 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
// tf.AvgPool legalization
//===----------------------------------------------------------------------===//
-// CHECK-LABEL: avgpool_valid_padding
-// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x20x7xf16>
-func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> {
-// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x20x7xf16>) -> tensor<2x12x20x7xf32>
-// CHECK: [[INIT:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: [[REDUCE:%.+]] = "mhlo.reduce_window"([[CONV32]], [[INIT]]) ( {
-// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
-// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
-// CHECK: "mhlo.return"([[ADD]])
-// CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
-// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
-// CHECK: [[DIV:%.+]] = chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
-// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16>
-// CHECK: return [[CONV16]]
-%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16>
// CHECK-LABEL: @avgpool_valid_padding
// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16>
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32>
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK: "mhlo.return"([[ADD]])
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME: -> tensor<2x3x5x7xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<2x3x5x7xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME: -> tensor<2x3x5x7xf16>
// CHECK: return [[CONV16]]
func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> {
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16>
return %0 : tensor<2x3x5x7xf16>
}
-// CHECK-LABEL: avgpool_same_padding
-func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32> {
-// CHECK: tf.AvgPool
-%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
-return %0 : tensor<2x4x7x7xf32>
// CHECK-LABEL: @avgpool_3d_valid_padding
// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16>
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32>
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK: "mhlo.return"([[ADD]])
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME: -> tensor<2x4x3x5x7xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<2x4x3x5x7xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME: -> tensor<2x4x3x5x7xf16>
// CHECK: return [[CONV16]]
func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> {
%0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16>
return %0 : tensor<2x4x3x5x7xf16>
}
// CHECK-LABEL: @avgpool_nchw_format
// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16>
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32>
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK: "mhlo.return"([[ADD]])
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x7x3x5xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<2x7x3x5xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME: -> tensor<2x7x3x5xf16>
// CHECK: return [[CONV16]]
func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> {
%0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16>
return %0 : tensor<2x7x3x5xf16>
}
// CHECK-LABEL: @avgpool_3d_ncdhw_format
// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16>
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32>
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK: "mhlo.return"([[ADD]])
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]>
// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x7x4x3x5xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<2x7x4x3x5xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME: -> tensor<2x7x4x3x5xf16>
// CHECK: return [[CONV16]]
func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> {
%0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16>
return %0 : tensor<2x7x4x3x5xf16>
}
// CHECK-LABEL: @avgpool_same_padding(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME: -> tensor<2x4x6x7xf32>
// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME: -> tensor<2x4x6x7xf32>
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32>
// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32>
// CHECK: }
func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> {
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
return %0 : tensor<2x4x6x7xf32>
}
// CHECK-LABEL: @avgpool_3d_same_padding(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME: -> tensor<2x4x4x6x7xf32>
// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME: -> tensor<2x4x4x6x7xf32>
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]]
// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32>
// CHECK: }
func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> {
%0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
return %0 : tensor<2x4x4x6x7xf32>
}
//===----------------------------------------------------------------------===//
// AvgPoolGrad op legalizations.
//===----------------------------------------------------------------------===//
-// CHECK-LABEL: func @avgpool_grad_valid_padding(
// CHECK-LABEL: @avgpool_grad_valid_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
@@ -3929,7 +4058,7 @@ func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24
return %result : tensor<10x24x32x64xf32>
}
-// CHECK-LABEL: func @avgpool_3d_grad_valid_padding(
// CHECK-LABEL: @avgpool_3d_grad_valid_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
@@ -3958,7 +4087,7 @@ func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<
return %result : tensor<10x8x24x32x64xf32>
}
-// CHECK-LABEL: func @avgpool_grad_same_padding(
// CHECK-LABEL: @avgpool_grad_same_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
@@ -3997,7 +4126,7 @@ func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9x
return %result : tensor<2x13x25x9xf32>
}
-// CHECK-LABEL: func @avgpool_3d_grad_same_padding(
// CHECK-LABEL: @avgpool_3d_grad_same_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
@@ -4035,7 +4164,7 @@ func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x1
return %result : tensor<2x8x13x25x9xf32>
}
-// CHECK-LABEL: func @avgpool_grad_nchw_format(
// CHECK-LABEL: @avgpool_grad_nchw_format(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
@@ -4074,7 +4203,7 @@ func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf
return %result : tensor<2x9x13x25xf32>
}
-// CHECK-LABEL: func @avgpool_3d_grad_ncdwh_format(
// CHECK-LABEL: @avgpool_3d_grad_ncdwh_format(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
@@ -4112,7 +4241,7 @@ func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8
return %result : tensor<2x9x8x13x25xf32>
}
-// CHECK-LABEL: func @avgpool_grad_bf16(
// CHECK-LABEL: @avgpool_grad_bf16(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>


@@ -1770,30 +1770,93 @@ static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
flatten_paddings);
}
// Helper function for dividing each entry of `pooled` by the count of its
// corresponding window, i.e., the number of non-padding entries of the window
// which an `AvgPool` operation performed on an `input_shape`-tensor would map
// to this entry, depending on `ksize` and `strides`. This function is used for
// `AvgPool` and `AvgPoolGrad` legalizations.
// `zero` is passed as a parameter because it can be reused from caller level.
// `pooled` must have `RankedTensorType`.
template <typename OpTy, int num_dims>
Operation *AvgPoolDivideByCount(
Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
const SmallVector<int64_t, num_dims> &ksize,
const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
PatternRewriter &rewriter) {
Location loc = op.getLoc();
RankedTensorType pooled_type =
pooled.getType().template cast<RankedTensorType>();
Type element_type = pooled_type.getElementType();
Operation *result = nullptr;
RankedTensorType orig_input_type =
RankedTensorType::get(input_shape, element_type);
if (op.padding() == "VALID") {
// All window counts are equal here because we don't have padding
// (each entry of `pooled` corresponds to a window that consists of
// original input entries only).
int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
std::multiplies<int64_t>());
// Divide `pooled` by window counts.
Value divisor =
GetScalarConstOfType(element_type, loc, window_count, &rewriter);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
result = rewriter.create<chlo::BroadcastDivOp>(
loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
} else {
assert(op.padding() == "SAME");
// For SAME padding, only original entries that contributed to a window
// are counted for the average of this window, not padded entries.
// Build all-ones tensor of same shape as the original input.
ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
// Get padding for the input.
DenseIntElementsAttr input_padding_attr =
GetReduceWindowPaddingAsAttr<num_dims>(
input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
// Count the 1's in each window, using the same padding as for the input,
// which gives us the window counts by which `pooled` needs to be divided.
auto divisor = rewriter.create<ReduceWindowOp>(
loc, pooled_type,
/*operand=*/all_ones_tensor,
/*init_value=*/zero,
/*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
/*window_strides=*/GetI64ElementsAttr(op.strides()),
/*base_dilations=*/DenseIntElementsAttr(),
/*window_dilations=*/DenseIntElementsAttr(),
/*padding=*/input_padding_attr);
BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
// Divide `pooled` by window counts.
result = rewriter.create<mhlo::DivOp>(loc, pooled_type, pooled, divisor);
}
return result;
}
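A small NumPy sketch of the divisor logic above (an illustration with hypothetical helper names, not code from this change): for VALID padding every window covers the same number of input entries, while for SAME padding the per-window count comes from running the same windowed sum over an all-ones tensor, so edge windows are divided only by the number of real, non-padding entries they cover.

import numpy as np

def pool_1d_same(x, ksize, stride):
    # SAME padding for one spatial dimension, then a per-window sum.
    out_len = -(-len(x) // stride)  # ceil(len / stride)
    pad = max((out_len - 1) * stride + ksize - len(x), 0)
    padded = np.pad(x, (pad // 2, pad - pad // 2))
    return np.array([padded[i * stride:i * stride + ksize].sum()
                     for i in range(out_len)])

def avg_pool_1d_same(x, ksize, stride):
    sums = pool_1d_same(x, ksize, stride)                  # the "dividend"
    counts = pool_1d_same(np.ones_like(x), ksize, stride)  # the "divisor"
    return sums / counts  # divide by per-window count, not by ksize

print(avg_pool_1d_same(np.arange(1.0, 6.0), ksize=3, stride=2))  # [1.5 3. 4.5]

The counts array plays the role of the ReduceWindowOp over the all-ones tensor in the SAME branch above.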
// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
// dimensions with add as the reduction function. The reduction result is
// then divided by the number of elements in the window.
-class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
template <typename OpTy, int num_dims>
class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
public:
-using OpRewritePattern::OpRewritePattern;
using OpRewritePattern<OpTy>::OpRewritePattern;
-LogicalResult matchAndRewrite(TF::AvgPoolOp op,
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
-auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
auto input_type =
op.value().getType().template dyn_cast<RankedTensorType>();
if (!input_type) return failure();
-// TODO(b/147217034): support other data formats.
-if (!IsDefaultDataFormat(op.data_format())) return failure();
-// TODO(b/147217034): support "SAME" padding.
-if (op.padding() != "VALID") return failure();
// We will do accumulation first; use a larger bitwidth if suitable.
Type input_element_type = input_type.getElementType();
Type sum_element_type = GetSumAccumulationType(input_element_type);
Type result_type;
// The result type for reduction and division with the proper element type.
-if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>())
if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
result_type =
RankedTensorType::get(ranked_type.getShape(), sum_element_type);
else
@@ -1809,9 +1872,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
// Create the tf.ReduceWindow op.
Value init =
GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
-DenseIntElementsAttr paddings_attr =
-GetReduceWindowPaddingAsAttr<4>(input_type.getShape(), op.ksize(),
-op.strides(), op.padding(), &rewriter);
DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
input_type.getShape(), op.ksize(), op.strides(), op.padding(),
&rewriter);
auto reduce = rewriter.create<ReduceWindowOp>(
op.getLoc(), result_type, input_value, init,
GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
@@ -1821,19 +1884,17 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
// Count the number of elements in the window. The following calculation
// is only valid for no paddings.
-SmallVector<int64_t, 4> ksize;
-GetI64ArrayAttrValues(op.ksize(), &ksize);
-int64_t count = std::accumulate(ksize.begin(), ksize.end(), 1,
-std::multiplies<int64_t>());
-// Divide by the number of elements in the window.
-Value divisor =
-GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
-auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
-Value result = rewriter.create<chlo::BroadcastDivOp>(
-op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims);
SmallVector<int64_t, num_dims> input_shape(
llvm::to_vector<num_dims>(input_type.getShape()));
SmallVector<int64_t, num_dims> ksize, strides;
GetI64ArrayAttrValues(op.ksize(), &ksize);
GetI64ArrayAttrValues(op.strides(), &strides);
Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
reduce.getResult(), input_shape, ksize, strides, op, init, rewriter);
// Convert back if we enlarged the element type's bitwidth.
Value result = result_op->getOpResult(0);
if (input_element_type != sum_element_type)
result =
rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
@@ -1843,6 +1904,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
}
};
using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
// `AvgPoolGradOp` is converted to the following operations:
// 1. Divide each entry of the output gradient (the gradient for the previous
// layer in backpropagation order) by the count of the corresponding window
@@ -1916,59 +1980,13 @@ class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
DimVector orig_input_shape(orig_input_shape_values.begin(),
orig_input_shape_values.end());
-RankedTensorType orig_input_type =
-RankedTensorType::get(orig_input_shape, element_type);
DimVector ksize, strides;
GetI64ArrayAttrValues(op.ksize(), &ksize);
GetI64ArrayAttrValues(op.strides(), &strides);
Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
-Operation *out_grad_divided = nullptr;
-if (op.padding() == "VALID") {
-// All window counts are equal here because we don't have padding
-// (each entry of `out_grad` corresponds to a window that consists of
-// original input entries only).
-int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
-std::multiplies<int64_t>());
-// Divide `out_grad` by window counts.
-Value divisor =
-GetScalarConstOfType(element_type, loc, window_count, &rewriter);
-auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
-out_grad_divided = rewriter.create<chlo::BroadcastDivOp>(
-loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims);
-} else {
-assert(op.padding() == "SAME");
-// For SAME padding, only original entries that contributed to a window
-// are counted for the average of this window, not padded entries.
-// Build all-ones tensor of same shape as the original input.
-ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
-auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
-// Get the same padding as for the original input.
-DenseIntElementsAttr orig_padding_attr =
-GetReduceWindowPaddingAsAttr<num_dims>(orig_input_shape, op.ksize(),
-op.strides(), op.padding(),
-&rewriter);
-// Count the 1's in each window, using the same padding as for the
-// original input, which gives us the window counts by which `out_grad`
-// needs to be divided.
-auto window_counts = rewriter.create<ReduceWindowOp>(
-loc, out_grad_type,
-/*operand=*/all_ones_tensor,
-/*init_value=*/zero,
-/*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
-/*window_strides=*/GetI64ElementsAttr(op.strides()),
-/*base_dilations=*/DenseIntElementsAttr(),
-/*window_dilations=*/DenseIntElementsAttr(),
-/*padding=*/orig_padding_attr);
-BuildReduceBody<AddOp>(element_type, &window_counts.body(), &rewriter);
-// Divide `out_grad` by window counts.
-out_grad_divided = rewriter.create<mhlo::DivOp>(loc, out_grad_type,
-out_grad, window_counts);
-}
auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
// Get same padding as for original input.
PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
@@ -5506,10 +5524,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
-ConvertAvgPoolOp, ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp,
-ConvertMaxPool2DOp, ConvertMaxPool3DOp, ConvertMaxPool2DGradOp,
-ConvertMaxPool3DGradOp, ConvertMeanOp, ConvertOneHotOp,
-ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp,
ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,


@@ -936,7 +936,7 @@ tf_xla_py_test(
srcs = ["pooling_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
-shard_count = 10,
shard_count = 20,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
@@ -956,7 +956,7 @@ tf_xla_py_test(
srcs = ["pooling_ops_3d_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
-shard_count = 10,
shard_count = 20,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],


@@ -75,9 +75,6 @@ class Pooling3DTest(xla_test.XLATestCase):
actual = vals.flatten()
self.assertAllClose(expected, actual)
-@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-" doesn't support all paddings and data "
-"formats")
def testAvgPool3dValidPadding(self):
expected_output = [20.5, 21.5, 22.5]
self._VerifyValues(
@@ -88,9 +85,6 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="VALID",
expected=expected_output)
-@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-" doesn't support all paddings and data "
-"formats")
def testAvgPool3dSamePadding(self):
expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5]
self._VerifyValues(
@@ -101,9 +95,6 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="SAME",
expected=expected_output)
-@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-" doesn't support all paddings and data "
-"formats")
def testAvgPool3dSamePaddingDifferentStrides(self):
expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5]
self._VerifyValues(


@@ -268,9 +268,6 @@ class PoolingTest(xla_test.XLATestCase):
expected=[1, 3, 9, 11])
# Average pooling
-@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-" doesn't support all paddings and data "
-"formats")
def testAvgPoolValidPadding(self):
expected_output = [7, 8, 9]
self._VerifyValues(
@@ -281,9 +278,6 @@ class PoolingTest(xla_test.XLATestCase):
padding="VALID",
expected=expected_output)
-@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-" doesn't support all paddings and data "
-"formats")
def testAvgPoolSamePadding(self):
expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
self._VerifyValues(