diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 3509ebfba4b..863de28df51 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -745,6 +745,30 @@ def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_AvgPool3DOp : TF_Op<"AvgPool3D", [NoSideEffect]> {
+  let summary = "Performs 3D average pooling on the input.";
+
+  let description = [{
+Each entry in `output` is the mean of the corresponding size `ksize`
+window in `value`.
+  }];
+
+  let arguments = (ins
+    TF_FpTensor:$value,
+
+    Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
+    Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
+    TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
+    DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
+  );
+
+  let results = (outs
+    TF_FpTensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
   let summary = "Computes gradients of the average pooling function.";
 
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 28a31058e99..7e897c36fe7 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -980,7 +980,7 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
 
 // SparseMatMul where one operand needs to be transposed and the other one not.
 //
-// CHECK-LABEL: func @test_sparse_mat_mul_with_transpose
+// CHECK-LABEL: @test_sparse_mat_mul_with_transpose
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
 // CHECK-SAME: -> tensor<3x5xf32>
@@ -990,7 +990,6 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
 // CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
 // CHECK-SAME: -> tensor<3x5xf32>
 // CHECK: return %[[RESULT]]
-// CHECK: }
 func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
   %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
   return %0: tensor<3x5xf32>
@@ -998,7 +997,7 @@ func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5
 
 // SparseMatMul where one operand needs to be casted and the other one not.
 //
-// CHECK-LABEL: func @test_sparse_mat_mul_with_cast
+// CHECK-LABEL: @test_sparse_mat_mul_with_cast
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
 // CHECK-SAME: -> tensor<3x5xf32>
@@ -1007,7 +1006,6 @@ func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5
 // CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
 // CHECK-SAME: -> tensor<3x5xf32>
 // CHECK: return %[[RESULT]]
-// CHECK: }
 func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
   %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
   return %0: tensor<3x5xf32>
@@ -3868,36 +3866,167 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
 // tf.AvgPool legalization
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: avgpool_valid_padding
-// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x20x7xf16>
-func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> {
-  // CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x20x7xf16>) -> tensor<2x12x20x7xf32>
-  // CHECK: [[INIT:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-  // CHECK: [[REDUCE:%.+]] = "mhlo.reduce_window"([[CONV32]], [[INIT]]) ( {
-  // CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
-  // CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
-  // CHECK: "mhlo.return"([[ADD]])
-  // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
-  // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
-  // CHECK: [[DIV:%.+]] = chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
-  // CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16>
-  // CHECK: return [[CONV16]]
-  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16>
+// CHECK-LABEL: @avgpool_valid_padding
+// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16>
+// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32>
+// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
+// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
+// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
+// CHECK: "mhlo.return"([[ADD]])
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
+// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]>
+// CHECK-SAME: -> tensor<2x3x5x7xf32>
+// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
+// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
+// CHECK-SAME: broadcast_dimensions = dense<>
+// CHECK-SAME: -> tensor<2x3x5x7xf32>
+// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
+// CHECK-SAME: -> tensor<2x3x5x7xf16>
+// CHECK: return [[CONV16]]
+func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> {
+  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16>
   return %0 : tensor<2x3x5x7xf16>
 }
 
-// CHECK-LABEL: avgpool_same_padding
-func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32> {
-  // CHECK: tf.AvgPool
-  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
-  return %0 : tensor<2x4x7x7xf32>
+// CHECK-LABEL: @avgpool_3d_valid_padding
+// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16>
+// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32>
+// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
+// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
+// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
+// CHECK: "mhlo.return"([[ADD]])
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]>
+// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]>
+// CHECK-SAME: -> tensor<2x4x3x5x7xf32>
+// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
+// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
+// CHECK-SAME: broadcast_dimensions = dense<>
+// CHECK-SAME: -> tensor<2x4x3x5x7xf32>
+// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
+// CHECK-SAME: -> tensor<2x4x3x5x7xf16>
+// CHECK: return [[CONV16]]
+func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> {
+  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16>
+  return %0 : tensor<2x4x3x5x7xf16>
+}
+
+// CHECK-LABEL: @avgpool_nchw_format
+// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16>
+// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32>
+// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
+// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
+// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
+// CHECK: "mhlo.return"([[ADD]])
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]>
+// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]>
+// CHECK-SAME: -> tensor<2x7x3x5xf32>
+// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
+// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
+// CHECK-SAME: broadcast_dimensions = dense<>
+// CHECK-SAME: -> tensor<2x7x3x5xf32>
+// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
+// CHECK-SAME: -> tensor<2x7x3x5xf16>
+// CHECK: return [[CONV16]]
+func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> {
+  %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16>
+  return %0 : tensor<2x7x3x5xf16>
+}
+
+// CHECK-LABEL: @avgpool_3d_ncdhw_format
+// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16>
+// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32>
+// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
+// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
+// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
+// CHECK: "mhlo.return"([[ADD]])
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]>
+// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]>
+// CHECK-SAME: -> tensor<2x7x4x3x5xf32>
+// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
+// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
+// CHECK-SAME: broadcast_dimensions = dense<>
+// CHECK-SAME: -> tensor<2x7x4x3x5xf32>
+// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
+// CHECK-SAME: -> tensor<2x7x4x3x5xf16>
+// CHECK: return [[CONV16]]
+func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> {
+  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16>
+  return %0 : tensor<2x7x4x3x5xf16>
+}
+
+// CHECK-LABEL: @avgpool_same_padding(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
+// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
+// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]>
+// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]>
+// CHECK-SAME: -> tensor<2x4x6x7xf32>
+// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32>
+// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
+// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
+// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
+// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]>
+// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]>
+// CHECK-SAME: -> tensor<2x4x6x7xf32>
+// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32>
+// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32>
+// CHECK: }
+func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> {
+  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
+  return %0 : tensor<2x4x6x7xf32>
+}
+
+// CHECK-LABEL: @avgpool_3d_same_padding(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
+// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
+// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]>
+// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]>
+// CHECK-SAME: -> tensor<2x4x4x6x7xf32>
+// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32>
+// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
+// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
+// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
+// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]>
+// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]>
+// CHECK-SAME: -> tensor<2x4x4x6x7xf32>
+// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]]
+// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32>
+// CHECK: }
+func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> {
+  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
+  return %0 : tensor<2x4x4x6x7xf32>
 }
 
 //===----------------------------------------------------------------------===//
 // AvgPoolGrad op legalizations.
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: func @avgpool_grad_valid_padding(
+// CHECK-LABEL: @avgpool_grad_valid_padding(
 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
@@ -3929,7 +4058,7 @@ func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24
   return %result : tensor<10x24x32x64xf32>
 }
 
-// CHECK-LABEL: func @avgpool_3d_grad_valid_padding(
+// CHECK-LABEL: @avgpool_3d_grad_valid_padding(
 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
@@ -3958,7 +4087,7 @@ func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<
   return %result : tensor<10x8x24x32x64xf32>
 }
 
-// CHECK-LABEL: func @avgpool_grad_same_padding(
+// CHECK-LABEL: @avgpool_grad_same_padding(
 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
@@ -3997,7 +4126,7 @@ func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9x
   return %result : tensor<2x13x25x9xf32>
 }
 
-// CHECK-LABEL: func @avgpool_3d_grad_same_padding(
+// CHECK-LABEL: @avgpool_3d_grad_same_padding(
 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
@@ -4035,7 +4164,7 @@ func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x1
   return %result : tensor<2x8x13x25x9xf32>
 }
 
-// CHECK-LABEL: func @avgpool_grad_nchw_format(
+// CHECK-LABEL: @avgpool_grad_nchw_format(
 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
@@ -4074,7 +4203,7 @@ func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf
   return %result : tensor<2x9x13x25xf32>
 }
 
-// CHECK-LABEL: func @avgpool_3d_grad_ncdwh_format(
+// CHECK-LABEL: @avgpool_3d_grad_ncdwh_format(
 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
@@ -4112,7 +4241,7 @@ func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8
   return %result : tensor<2x9x8x13x25xf32>
 }
 
-// CHECK-LABEL: func @avgpool_grad_bf16(
+// CHECK-LABEL: @avgpool_grad_bf16(
 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
 // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index af176dc1324..4549386ce16 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -1770,30 +1770,93 @@ static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
       flatten_paddings);
 }
 
+// Helper function for dividing each entry of `pooled` by the count of its
+// corresponding window, i.e., the number of non-padding entries of the window
+// which an `AvgPool` operation performed on an `input_shape`-tensor would map
+// to this entry, depending on `ksize` and `strides`. This function is used for
+// `AvgPool` and `AvgPoolGrad` legalizations.
+// `zero` is passed as a parameter because it can be reused from caller level.
+// `pooled` must have `RankedTensorType`.
+template <typename OpTy, int num_dims>
+Operation *AvgPoolDivideByCount(
+    Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
+    const SmallVector<int64_t, num_dims> &ksize,
+    const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
+    PatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  RankedTensorType pooled_type =
+      pooled.getType().template cast<RankedTensorType>();
+  Type element_type = pooled_type.getElementType();
+  Operation *result = nullptr;
+  RankedTensorType orig_input_type =
+      RankedTensorType::get(input_shape, element_type);
+
+  if (op.padding() == "VALID") {
+    // All window counts are equal here because we don't have padding
+    // (each entry of `pooled` corresponds to a window that consists of
+    //  original input entries only).
+    int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
+                                           std::multiplies<int64_t>());
+    // Divide `pooled` by window counts.
+    Value divisor =
+        GetScalarConstOfType(element_type, loc, window_count, &rewriter);
+    auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
+    result = rewriter.create<chlo::BroadcastDivOp>(
+        loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
+  } else {
+    assert(op.padding() == "SAME");
+    // For SAME padding, only original entries that contributed to a window
+    // are counted for the average of this window, not padded entries.
+
+    // Build all-ones tensor of same shape as the original input.
+    ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
+    auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
+
+    // Get padding for the input.
+    DenseIntElementsAttr input_padding_attr =
+        GetReduceWindowPaddingAsAttr<num_dims>(
+            input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
+
+    // Count the 1's in each window, using the same padding as for the input,
+    // which gives us the window counts by which `pooled` needs to be divided.
+    auto divisor = rewriter.create<ReduceWindowOp>(
+        loc, pooled_type,
+        /*operand=*/all_ones_tensor,
+        /*init_value=*/zero,
+        /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
+        /*window_strides=*/GetI64ElementsAttr(op.strides()),
+        /*base_dilations=*/DenseIntElementsAttr(),
+        /*window_dilations=*/DenseIntElementsAttr(),
+        /*padding=*/input_padding_attr);
+    BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
+
+    // Divide `pooled` by window counts.
+    result = rewriter.create<DivOp>(loc, pooled_type, pooled, divisor);
+  }
+  return result;
+}
+
 // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
 // dimensions with add as the reduction function. The reduction result is
 // then divided by the number of elements in the window.
-class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
+template <typename OpTy, int num_dims>
+class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
  public:
-  using OpRewritePattern::OpRewritePattern;
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TF::AvgPoolOp op,
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
+    auto input_type =
+        op.value().getType().template dyn_cast<RankedTensorType>();
     if (!input_type) return failure();
 
-    // TODO(b/147217034): support other data formats.
-    if (!IsDefaultDataFormat(op.data_format())) return failure();
-    // TODO(b/147217034): support "SAME" padding.
-    if (op.padding() != "VALID") return failure();
-
     // We will do accumulation first; use a larger bitwidth if suitable.
     Type input_element_type = input_type.getElementType();
     Type sum_element_type = GetSumAccumulationType(input_element_type);
     Type result_type;
 
     // The result type for reduction and division with the proper element type.
-    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>())
+    if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
       result_type =
           RankedTensorType::get(ranked_type.getShape(), sum_element_type);
     else
@@ -1809,9 +1872,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
     // Create the tf.ReduceWindow op.
     Value init =
         GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
-    DenseIntElementsAttr paddings_attr =
-        GetReduceWindowPaddingAsAttr<4>(input_type.getShape(), op.ksize(),
-                                        op.strides(), op.padding(), &rewriter);
+    DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
+        input_type.getShape(), op.ksize(), op.strides(), op.padding(),
+        &rewriter);
     auto reduce = rewriter.create<ReduceWindowOp>(
         op.getLoc(), result_type, input_value, init,
         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
@@ -1821,19 +1884,17 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
 
     // Count the number of elements in the window. The following calculation
    // is only valid for no paddings.
-    SmallVector<int64_t, 4> ksize;
+    SmallVector<int64_t, num_dims> input_shape(
+        llvm::to_vector<num_dims>(input_type.getShape()));
+    SmallVector<int64_t, num_dims> ksize, strides;
     GetI64ArrayAttrValues(op.ksize(), &ksize);
-    int64_t count = std::accumulate(ksize.begin(), ksize.end(), 1,
-                                    std::multiplies<int64_t>());
+    GetI64ArrayAttrValues(op.strides(), &strides);
 
-    // Divide by the number of elements in the window.
-    Value divisor =
-        GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
-    auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
-    Value result = rewriter.create<chlo::BroadcastDivOp>(
-        op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims);
+    Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
+        reduce.getResult(), input_shape, ksize, strides, op, init, rewriter);
 
     // Convert back if we enlarged the element type's bitwidth.
+    Value result = result_op->getOpResult(0);
     if (input_element_type != sum_element_type)
       result =
           rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
@@ -1843,6 +1904,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
   }
 };
 
+using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
+using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
+
 // `AvgPoolGradOp` is converted to the following operations:
 // 1. Divide each entry of the output gradient (the gradient for the previous
 //    layer in backpropagation order) by the count of the corresponding window
 //
@@ -1916,59 +1980,13 @@ class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
     auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
     DimVector orig_input_shape(orig_input_shape_values.begin(),
                                orig_input_shape_values.end());
-    RankedTensorType orig_input_type =
-        RankedTensorType::get(orig_input_shape, element_type);
     DimVector ksize, strides;
     GetI64ArrayAttrValues(op.ksize(), &ksize);
     GetI64ArrayAttrValues(op.strides(), &strides);
     Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
 
-    Operation *out_grad_divided = nullptr;
-    if (op.padding() == "VALID") {
-      // All window counts are equal here because we don't have padding
-      // (each entry of `out_grad` corresponds to a window that consists of
-      //  original input entries only).
-      int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
-                                             std::multiplies<int64_t>());
-      // Divide `out_grad` by window counts.
-      Value divisor =
-          GetScalarConstOfType(element_type, loc, window_count, &rewriter);
-      auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
-      out_grad_divided = rewriter.create<chlo::BroadcastDivOp>(
-          loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims);
-    } else {
-      assert(op.padding() == "SAME");
-      // For SAME padding, only original entries that contributed to a window
-      // are counted for the average of this window, not padded entries.
-
-      // Build all-ones tensor of same shape as the original input.
-      ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
-      auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
-
-      // Get the same padding as for the original input.
-      DenseIntElementsAttr orig_padding_attr =
-          GetReduceWindowPaddingAsAttr<num_dims>(orig_input_shape, op.ksize(),
-                                                 op.strides(), op.padding(),
-                                                 &rewriter);
-
-      // Count the 1's in each window, using the same padding as for the
-      // original input, which gives us the window counts by which `out_grad`
-      // needs to be divided.
-      auto window_counts = rewriter.create<ReduceWindowOp>(
-          loc, out_grad_type,
-          /*operand=*/all_ones_tensor,
-          /*init_value=*/zero,
-          /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
-          /*window_strides=*/GetI64ElementsAttr(op.strides()),
-          /*base_dilations=*/DenseIntElementsAttr(),
-          /*window_dilations=*/DenseIntElementsAttr(),
-          /*padding=*/orig_padding_attr);
-      BuildReduceBody<AddOp>(element_type, &window_counts.body(), &rewriter);
-
-      // Divide `out_grad` by window counts.
-      out_grad_divided = rewriter.create<DivOp>(loc, out_grad_type, out_grad,
-                                                window_counts);
-    }
+    auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
+        out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
 
     // Get same padding as for original input.
     PaddingArray orig_padding = GetReduceWindowPaddingAsArray(
@@ -5506,10 +5524,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
       ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
       ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
       ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
-      ConvertAvgPoolOp, ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp,
-      ConvertMaxPool2DOp, ConvertMaxPool3DOp, ConvertMaxPool2DGradOp,
-      ConvertMaxPool3DGradOp, ConvertMeanOp, ConvertOneHotOp,
-      ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
+      ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp,
+      ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
+      ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
+      ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
       ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
       ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, ConvertSoftmaxOp,
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 0b5a6c147dc..1dd61c235a8 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -936,7 +936,7 @@ tf_xla_py_test(
     srcs = ["pooling_ops_test.py"],
     enable_mlir_bridge = True,
     python_version = "PY3",
-    shard_count = 10,
+    shard_count = 20,
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
     ],
@@ -956,7 +956,7 @@ tf_xla_py_test(
     srcs = ["pooling_ops_3d_test.py"],
    enable_mlir_bridge = True,
    python_version = "PY3",
-    shard_count = 10,
+    shard_count = 20,
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
     ],
diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py
index a833daa39be..9eda74b55a9 100644
--- a/tensorflow/compiler/tests/pooling_ops_3d_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py
@@ -75,9 +75,6 @@ class Pooling3DTest(xla_test.XLATestCase):
     actual = vals.flatten()
     self.assertAllClose(expected, actual)
 
-  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-                                 " doesn't support all paddings and data "
-                                 "formats")
   def testAvgPool3dValidPadding(self):
     expected_output = [20.5, 21.5, 22.5]
     self._VerifyValues(
@@ -88,9 +85,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="VALID",
         expected=expected_output)
 
-  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-                                 " doesn't support all paddings and data "
-                                 "formats")
   def testAvgPool3dSamePadding(self):
     expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5]
     self._VerifyValues(
@@ -101,9 +95,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="SAME",
         expected=expected_output)
 
-  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-                                 " doesn't support all paddings and data "
-                                 "formats")
   def testAvgPool3dSamePaddingDifferentStrides(self):
     expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5]
     self._VerifyValues(
diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py
index 293e1010b08..d9393387c0d 100644
--- a/tensorflow/compiler/tests/pooling_ops_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_test.py
@@ -268,9 +268,6 @@ class PoolingTest(xla_test.XLATestCase):
         expected=[1, 3, 9, 11])
 
   # Average pooling
-  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-                                 " doesn't support all paddings and data "
-                                 "formats")
   def testAvgPoolValidPadding(self):
     expected_output = [7, 8, 9]
     self._VerifyValues(
@@ -281,9 +278,6 @@ class PoolingTest(xla_test.XLATestCase):
         padding="VALID",
         expected=expected_output)
 
-  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
-                                 " doesn't support all paddings and data "
-                                 "formats")
   def testAvgPoolSamePadding(self):
     expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
     self._VerifyValues(
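Note (reviewer aside, not part of the patch): the sketch below is a minimal, illustrative Python check of the averaging semantics the new lowering encodes. With SAME padding, TensorFlow's average pooling divides each window sum only by the number of non-padded entries, which is why AvgPoolDivideByCount emits a second reduce_window over an all-ones tensor as the divisor. The shapes, values, and use of tf.nn.avg_pool3d here are my own assumptions for demonstration, not code from this change.

# Minimal sketch (assumes TF 2.x eager execution); illustrative only.
import numpy as np
import tensorflow as tf

# NDHWC input of shape [1, 1, 1, 3, 1] holding [1, 2, 3] along the W axis.
x = tf.constant(np.arange(1.0, 4.0, dtype=np.float32).reshape(1, 1, 1, 3, 1))

# Window 1x1x2, stride 1, SAME padding: the last window covers [3, pad].
y = tf.nn.avg_pool3d(
    x, ksize=[1, 1, 1, 2, 1], strides=[1, 1, 1, 1, 1],
    padding="SAME", data_format="NDHWC")

# Expect [1.5, 2.5, 3.0]: the padded entry is excluded from the divisor, so
# the last window is divided by 1 -- the behavior reproduced by the
# ones-tensor reduce_window divisor in the legalization above.
print(y.numpy().reshape(-1))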