diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 65ca3ea4dbd..146af4adbbc 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -745,6 +745,29 @@ def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
+  let summary = "Computes gradients of the 3D average pooling function.";
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    I32Tensor:$orig_input_shape,
+    TF_FpTensor:$grad,
+
+    Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
+    Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
+    TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
+    DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
+  );
+
+  let results = (outs
+    TF_FpTensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> {
   let summary = "Multiplies slices of two tensors in batches.";
 
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index d089f80d571..657f34ee02e 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -175,6 +175,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:padding",
         "//tensorflow/compiler/xla/client:sharding_builder",
+        "//tensorflow/compiler/xla/client/lib:conv_grad_size_util",
         "//tensorflow/core:framework",
         "//tensorflow/core/kernels:conv_grad_shape_utils",
         "//tensorflow/core/lib/bfloat16",
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 2cd98ea3f6b..10d69221979 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1096,7 +1096,6 @@ func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: te
   // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
   // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
   // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32>
-  // CHECK: }
   %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
      data_format = "NHWC",
      ksize = [1, 2, 2, 1],
@@ -1120,7 +1119,6 @@ func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_outpu
   // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
   // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
   // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
-  // CHECK: }
   %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
   return %result : tensor<10x8x24x24x64xf32>
 }
@@ -3838,6 +3836,260 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
   return %0 : tensor<2x4x7x7xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// AvgPoolGrad op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @avgpool_grad_valid_padding(
+// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
+// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor<f32>
+// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
+// CHECK-SAME: broadcast_dimensions = dense<[]>
+// CHECK-SAME: -> tensor<10x12x16x64xf32>
+// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
+// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
+// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
+// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]>
+// CHECK-SAME: -> tensor<10x25x33x64xf32>
+// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
+// CHECK-SAME: window_strides = dense<1>
+// CHECK-SAME: -> tensor<10x24x32x64xf32>
+// CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32>
+func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
+  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
+  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
+    data_format = "NHWC",
+    ksize = [1, 2, 2, 1],
+    padding = "VALID",
+    strides = [1, 2, 2, 1]
+  } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32>
+  return %result : tensor<10x24x32x64xf32>
+}
+
+// CHECK-LABEL: func @avgpool_3d_grad_valid_padding(
+// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
+// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor<f32>
+// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
+// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
+// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]>
+// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]>
+// CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]>
+// CHECK-SAME: -> tensor<10x8x25x33x64xf32>
+// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]>
+// CHECK-SAME: window_strides = dense<1>
+// CHECK-SAME: -> tensor<10x8x24x32x64xf32>
+// CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32>
+func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
+  %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>)
+  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
+    data_format = "NDHWC",
+    ksize = [1, 1, 2, 2, 1],
+    padding = "VALID",
+    strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32>
+  return %result : tensor<10x8x24x32x64xf32>
+}
+
+// CHECK-LABEL: func @avgpool_grad_same_padding(
+// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
+// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
+// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]>
+// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]>
+// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]>
+// CHECK-SAME: -> tensor<2x4x7x9xf32>
+// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32>
+// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
+// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]>
+// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
+// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]>
+// CHECK-SAME: -> tensor<2x14x27x9xf32>
+// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
+// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]>
+// CHECK-SAME: window_strides = dense<1>
+// CHECK-SAME: -> tensor<2x13x25x9xf32>
+// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32>
+func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
+  %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>)
+  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
+    data_format = "NHWC",
+    ksize = [1, 2, 3, 1],
+    padding = "SAME",
+    strides = [1, 4, 4, 1]
+  } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32>
+  return %result : tensor<2x13x25x9xf32>
+}
+
+// CHECK-LABEL: func @avgpool_3d_grad_same_padding(
+// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
+// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
+// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]>
+// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]>
+// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]>
+// CHECK-SAME: -> tensor<2x8x4x7x9xf32>
+// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32>
+// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
+// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1, 0]>
+// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]>
+// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]>
+// CHECK-SAME: -> tensor<2x8x14x27x9xf32>
+// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
+// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]>
+// CHECK-SAME: window_strides = dense<1>
+// CHECK-SAME: -> tensor<2x8x13x25x9xf32>
+// CHECK: return %[[RESULT]] : tensor<2x8x13x25x9xf32>
+func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
+  %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>)
+  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
+    data_format = "NDHWC",
+    ksize = [1, 1, 2, 3, 1],
+    padding = "SAME",
+    strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32>
+  return %result : tensor<2x8x13x25x9xf32>
+}
+
+// CHECK-LABEL: func @avgpool_grad_nchw_format(
+// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
+// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
+// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]>
+// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]>
+// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]>
+// CHECK-SAME: -> tensor<2x9x4x7xf32>
+// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32>
+// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
+// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]>
+// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]>
+// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]>
+// CHECK-SAME: -> tensor<2x9x14x27xf32>
+// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
+// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]>
+// CHECK-SAME: window_strides = dense<1>
+// CHECK-SAME: -> tensor<2x9x13x25xf32>
+// CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32>
+func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
+  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>)
+  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
+    data_format = "NCHW",
+    ksize = [1, 1, 2, 3],
+    padding = "SAME",
+    strides = [1, 1, 4, 4]
+  } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32>
+  return %result : tensor<2x9x13x25xf32>
+}
+
+// CHECK-LABEL: func @avgpool_3d_grad_ncdwh_format(
+// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
+// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
+// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]>
+// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]>
+// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]>
+// CHECK-SAME: -> tensor<2x9x8x4x7xf32>
+// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32>
+// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
+// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]>
+// CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]>
+// CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]>
+// CHECK-SAME: -> tensor<2x9x8x14x27xf32>
+// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
+// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
+// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]>
+// CHECK-SAME: window_strides = dense<1> : tensor<5xi64>
+// CHECK-SAME: -> tensor<2x9x8x13x25xf32>
+// CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32>
+func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
+  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>)
+  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
+    data_format = "NCDHW",
+    ksize = [1, 1, 1, 2, 3],
+    padding = "SAME",
+    strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32>
+  return %result : tensor<2x9x8x13x25xf32>
+}
+
+// CHECK-LABEL: func @avgpool_grad_bf16(
+// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
+// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<bf16>
+// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor<bf16>
+// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
+// CHECK-SAME: broadcast_dimensions = dense<[]>
+// CHECK-SAME: -> tensor<10x12x16x64xbf16>
+// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
+// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
+// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
+// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]>
+// CHECK-SAME: -> tensor<10x25x33x64xbf16>
+// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = "xla_hlo.convert"(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32>
+// CHECK: %[[ZERO_F32:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ( {
+// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
+// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
+// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
+// CHECK-SAME: window_strides = dense<1>
+// CHECK-SAME: -> tensor<10x24x32x64xf32>
+// CHECK: %[[RESULT_CONVERTED:.*]] = "xla_hlo.convert"(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16>
+// CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16>
+func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
+  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
+  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
+    data_format = "NHWC",
+    ksize = [1, 2, 2, 1],
+    padding = "VALID",
+    strides = [1, 2, 2, 1]
+  } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16>
+  return %result : tensor<10x24x32x64xbf16>
+}
+
 // CHECK-LABEL: xla_sharding
 func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
   // CHECK-NEXT: "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, xla_hlo.sharding = ""}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 1788cd1b270..73aa555fb2d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -51,6 +51,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
+#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
 #include "tensorflow/compiler/xla/client/padding.h"
 #include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -570,7 +571,7 @@ Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
                          dimension_numbers, precision_config);
 }
 
-// Builds body for reduce op by using the using the template binary op as the
+// Builds body for reduce op by using the template binary op as the
 // reducer op.
 template <typename Op>
 static void BuildReduceBody(Type element_type, Region *body,
@@ -1697,17 +1698,21 @@ class ConvertFusedBatchNormV3Op
   }
 };
 
-// Returns padding attribute for ReduceWindow op with given params.
+using PaddingArray =
+    std::vector<std::pair<int64_t, int64_t>>;
+
+// Returns padding values for ReduceWindow op as a vector of pairs.
 //
 // Requires padding to be either 'SAME' or 'VALID' and the number of input
 // dimensions to be equal to the size of window dimensions and window strides.
 template <int num_dims>
-static DenseIntElementsAttr GetReduceWindowPadding(
+static PaddingArray GetReduceWindowPaddingAsArray(
     llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
     ArrayAttr window_strides, StringRef padding, Builder *builder) {
-  if (padding == "VALID") return {};
-  DCHECK_EQ(padding.str(), "SAME");
-
+  if (padding == "VALID") {
+    return PaddingArray(num_dims, std::make_pair(0, 0));
+  }
+  assert(padding == "SAME");
   llvm::SmallVector<int64_t, num_dims> input_shape, window_shape, strides;
   input_shape.reserve(input_dims.size());
@@ -1720,9 +1725,21 @@ static DenseIntElementsAttr GetReduceWindowPadding(
   for (Attribute attr : window_strides)
     strides.push_back(attr.cast<IntegerAttr>().getInt());
 
-  std::vector<std::pair<int64_t, int64_t>> paddings =
-      ::xla::MakePadding(input_shape, window_shape, strides,
-                         ::xla::Padding::kSame);
+  PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides,
+                                             ::xla::Padding::kSame);
+  return paddings;
+}
+
+// Same as GetReduceWindowPaddingAsArray but returns padding as
+// DenseIntElementsAttr. Returns empty attribute for `VALID` padding.
+template <int num_dims>
+static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
+    llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
+    ArrayAttr window_strides, StringRef padding, Builder *builder) {
+  if (padding == "VALID") return {};
+  assert(padding == "SAME");
+  PaddingArray paddings = GetReduceWindowPaddingAsArray<num_dims>(
+      input_dims, window_dims, window_strides, padding, builder);
   int64_t rank = paddings.size();
   llvm::SmallVector<int64_t, 8> flatten_paddings(rank * 2);
   for (int i = 0; i < rank; i++) {
@@ -1774,8 +1791,8 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
     Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), 0,
                                       &rewriter);
     DenseIntElementsAttr paddings_attr =
-        GetReduceWindowPadding<4>(input_type.getShape(), op.ksize(),
-                                  op.strides(), op.padding(), &rewriter);
+        GetReduceWindowPaddingAsAttr<4>(input_type.getShape(), op.ksize(),
+                                        op.strides(), op.padding(), &rewriter);
     auto reduce = rewriter.create<ReduceWindowOp>(
         op.getLoc(), result_type, input_value, init,
         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
@@ -1807,6 +1824,229 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
   }
 };
 
+// `AvgPoolGradOp` is converted to the following operations:
+// 1. Divide each entry of the output gradient (the gradient for the previous
+//    layer in backpropagation order) by the count of the corresponding window
+//    (i.e., the number of non-padding entries of the window which `AvgPool`
+//    has mapped to this entry in forward propagation).
+// 2. Add appropriate interior and exterior padding for step 3 (see example
+//    below).
+// 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape
+//    as windows) and stride 1 in each dimension. This is implemented as a
+//    `ReduceWindowOp` with `AddOp` as body.
+//
+// Example:
+// Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2,
+// and SAME padding with 0's. It is defined by
+//    f(x) = [ (x_1 + x_2 + x_3) / 3 ]  ( x = (x_1, x_2, x_3, x_4) )
+//           [ (x_3 + x_4 + 0) / 2 ]    (the 0 results from right padding)
+// Note that for SAME padding in `AvgPool` the padded entries are not counted
+// for the average, this is why the second denominator is 2 and not 3.
+// The Jacobian Df is
+//    [ 1/3  1/3  1/3  0   ]
+//    [ 0    0    1/2  1/2 ]
+//
+// Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only
+// needs the original input shape and not the tensor as argument).
+// Let v = [ 4  6 ]^T be the output gradient (^T = transposed). Then the
+// average pool gradient is given by
+//    Df^T * v = [ 4/3  4/3  13/3  3 ]^T
+// Instead of a matrix-vector-multiplication we can utilize the sparsity and
+// structure of Df by using the 3-step approach from above:
+// 1. Divide output gradient v by window counts: [ 4/3  6/2 ]^T
+// 2. Add appropriate padding: [ 0  0  4/3  0  3  0 ]^T
+// 3. Convolve with kernel [ 1  1  1 ]: [ 4/3  4/3  13/3  3 ]^T
+//
+// Note that the padding in step 2. is chosen in such a way that the subsequent
+// convolution produces the gradient. Higher dimensions, different padding, and
+// different windows/strides work in a similar way, the main difference is in
+// the computation of the paddings in step 2.
+//
+// For more details on backpropagation for convolution of which `AvgPoolGrad`
+// is a special case see `tensorflow/core/kernels/conv_grad_ops.h`.
+// `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more
+// examples for different cases.
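+//
+// In terms of the ops emitted below, the three steps map to:
+//   1. `xla_chlo.broadcast_divide` by the constant window size (VALID
+//      padding), or `xla_hlo.divide` by an `xla_hlo.reduce_window` over an
+//      all-ones tensor that counts the non-padding entries of each window
+//      (SAME padding).
+//   2. `xla_hlo.pad` with interior padding of `stride - 1` plus the exterior
+//      padding computed via `xla::ConvGradExtractAndVerifyDimension` (minus
+//      the forward pass's own SAME padding).
+//   3. `xla_hlo.reduce_window` with an `xla_hlo.add` body and all window
+//      strides equal to 1.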
+template <typename OpTy, int num_dims>
+class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
+  using DimVector = SmallVector<int64_t, num_dims>;
+
+ public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    tensorflow::TensorFormat data_format;
+    if (!FormatFromString(op.data_format().str(), &data_format)) {
+      return failure();
+    }
+    // `out_grad` is the gradient that was propagated via backpropagation from
+    // the output layer.
+    Value out_grad = op.grad();
+    auto out_grad_type =
+        out_grad.getType().template dyn_cast<RankedTensorType>();
+    if (!out_grad_type) {
+      return failure();
+    }
+    Type element_type = out_grad_type.getElementType();
+    DenseIntElementsAttr orig_input_shape_attr;
+    if (!matchPattern(op.orig_input_shape(),
+                      m_Constant(&orig_input_shape_attr))) {
+      return failure();
+    }
+    auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
+    DimVector orig_input_shape(orig_input_shape_values.begin(),
+                               orig_input_shape_values.end());
+    RankedTensorType orig_input_type =
+        RankedTensorType::get(orig_input_shape, element_type);
+    DimVector ksize, strides;
+    GetI64ArrayAttrValues(op.ksize(), &ksize);
+    GetI64ArrayAttrValues(op.strides(), &strides);
+    Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
+
+    Operation *out_grad_divided = nullptr;
+    if (op.padding() == "VALID") {
+      // All window counts are equal here because we don't have padding
+      // (each entry of `out_grad` corresponds to a window that consists of
+      // original input entries only).
+      int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
+                                             std::multiplies<int64_t>());
+      // Divide `out_grad` by window counts.
+      Value divisor =
+          GetScalarConstOfType(element_type, loc, window_count, &rewriter);
+      auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
+      out_grad_divided = rewriter.create<xla_chlo::BroadcastDivOp>(
+          loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims);
+    } else {
+      assert(op.padding() == "SAME");
+      // For SAME padding, only original entries that contributed to a window
+      // are counted for the average of this window, not padded entries.
+
+      // Build all-ones tensor of same shape as the original input.
+      ElementsAttr splat = xla::getSplat(&rewriter, orig_input_type, 1);
+      auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
+
+      // Get the same padding as for the original input.
+      DenseIntElementsAttr orig_padding_attr =
+          GetReduceWindowPaddingAsAttr<num_dims>(orig_input_shape, op.ksize(),
+                                                 op.strides(), op.padding(),
+                                                 &rewriter);
+
+      // Count the 1's in each window, using the same padding as for the
+      // original input, which gives us the window counts by which `out_grad`
+      // needs to be divided.
+      auto window_counts = rewriter.create<ReduceWindowOp>(
+          loc, out_grad_type,
+          /*operand=*/all_ones_tensor,
+          /*init_value=*/zero,
+          /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
+          /*window_strides=*/GetI64ElementsAttr(op.strides()),
+          /*base_dilations=*/DenseIntElementsAttr(),
+          /*window_dilations=*/DenseIntElementsAttr(),
+          /*padding=*/orig_padding_attr);
+      BuildReduceBody<AddOp>(element_type, &window_counts.body(), &rewriter);
+
+      // Divide `out_grad` by window counts.
+      out_grad_divided = rewriter.create<DivOp>(
+          loc, out_grad_type, out_grad, window_counts);
+    }
+
+    // Get same padding as for original input.
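+    // (This is all zeros for VALID padding; for SAME padding it is the same
+    // padding the forward pass used, which is subtracted again from the
+    // exterior padding computed below.)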
+    PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
+        orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
+
+    // Add padding around `out_grad_divided` values in such a way that the
+    // subsequent `ReduceWindowOp` produces the gradient.
+    DimVector out_grad_shape(
+        llvm::to_vector<num_dims>(out_grad_type.getShape()));
+    DimVector low_padding(num_dims, 0);
+    DimVector high_padding(num_dims, 0);
+    DimVector interior_padding(num_dims, 0);
+    constexpr int num_spatial_dims = num_dims - 2;
+    for (int i = 0; i < num_spatial_dims; ++i) {
+      int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
+      int orig_input_shape_padded_in_dim = orig_input_shape[dim] +
+                                           orig_padding[dim].first +
+                                           orig_padding[dim].second;
+      // Set interior padding such that neighboring entries from
+      // `out_grad_divided` have distance `strides[dim]` from each other in
+      // every dimension.
+      interior_padding[dim] = strides[dim] - 1;
+      // Set exterior padding in the same way as for convolution gradient
+      // computation.
+      auto status = ::xla::ConvGradExtractAndVerifyDimension(
+          /*input_size=*/orig_input_shape_padded_in_dim,
+          /*filter_size=*/ksize[dim],
+          /*output_size=*/out_grad_shape[dim],
+          /*dilation=*/1,
+          /*stride=*/strides[dim],
+          /*padding=*/::xla::Padding::kValid);
+      if (!status.ok()) {
+        return failure();
+      }
+      ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim =
+          status.ValueOrDie();
+      // Subtract the original exterior padding since it doesn't contribute to
+      // the gradient. Note that we save one `PadOp` and some unnecessary kernel
+      // computations, compared to the `xla::AvgPoolGrad` implementation, by
+      // subtracting the original exterior padding before `ReduceWindowOp`
+      // instead of trimming the result of `ReduceWindowOp` (the final result is
+      // the same because all strides are 1).
+      low_padding[dim] =
+          conv_grad_spatial_dim.pad_before - orig_padding[dim].first;
+      high_padding[dim] =
+          conv_grad_spatial_dim.pad_after - orig_padding[dim].second;
+
+      // Update `out_grad_shape` to result shape of following `PadOp`.
+      out_grad_shape[dim] = low_padding[dim] + high_padding[dim] +
+                            (out_grad_shape[dim] - 1) * strides[dim] + 1;
+    }
+    Value reduce_window_input = rewriter.create<PadOp>(
+        loc, RankedTensorType::get(out_grad_shape, element_type),
+        /*operand=*/out_grad_divided->getOpResult(0),
+        /*padding_value=*/zero,
+        /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter),
+        /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter),
+        /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter));
+
+    // Compute result by convolving `reduce_window_input` with an all-ones
+    // kernel, using `ReduceWindowOp` with `AddOp` body.
+
+    Type sum_element_type = GetSumAccumulationType(element_type);
+    if (element_type != sum_element_type) {
+      // Convert to appropriate sum accumulation type to avoid precision loss.
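+      // (For example, bf16 gradients are accumulated in f32 and converted
+      // back after the reduction, as checked by the avgpool_grad_bf16 test.)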
+      reduce_window_input = rewriter.create<ConvertOp>(loc, reduce_window_input,
+                                                       sum_element_type);
+      zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter);
+    }
+    auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter);
+    auto reduce_window_op = rewriter.create<ReduceWindowOp>(
+        loc, RankedTensorType::get(orig_input_shape, sum_element_type),
+        /*operand=*/reduce_window_input,
+        /*init_value=*/zero,
+        /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
+        /*window_strides=*/ones,
+        /*base_dilations=*/DenseIntElementsAttr(),
+        /*window_dilations=*/DenseIntElementsAttr(),
+        /*padding=*/DenseIntElementsAttr());
+    BuildReduceBody<AddOp>(sum_element_type, &reduce_window_op.body(),
+                           &rewriter);
+    Value result = reduce_window_op.getResult();
+
+    if (element_type != sum_element_type) {
+      // Convert back to original element type.
+      result = rewriter.create<ConvertOp>(op.getLoc(), result, element_type);
+    }
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
+using ConvertAvgPool2DGradOp =
+    ConvertAvgPoolGradOp<TF::AvgPoolGradOp, /*num_dims=*/4>;
+using ConvertAvgPool3DGradOp =
+    ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;
+
 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
 // dimensions with max as the reduction function.
 //
@@ -1831,7 +2071,7 @@ class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
     if (!input_ty) return failure();
 
-    DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<num_dims>(
+    DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
     auto reduce = rewriter.create<ReduceWindowOp>(
         loc, op.getType(), op.input(), init.getResult(),
@@ -3381,7 +3621,7 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
     auto input_ty = op.orig_input().getType().template dyn_cast<RankedTensorType>();
     if (!input_ty) return failure();
 
-    DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<num_dims>(
+    DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
        input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
 
    auto result = rewriter.create<SelectAndScatterOp>(
@@ -5202,11 +5442,12 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
      ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
      ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
-      ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
-      ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
-      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
-      ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp, ConvertRangeOp,
-      ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
+      ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertAvgPool2DGradOp,
+      ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
+      ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
+      ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
+      ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
+      ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
      ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
      ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 42353451408..0b5a6c147dc 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -934,6 +934,7 @@ tf_xla_py_test(
     name = "pooling_ops_test",
     size = "medium",
    srcs = ["pooling_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     shard_count = 10,
     tags = [
@@ -953,6 +954,7 @@ tf_xla_py_test(
     name = "pooling_ops_3d_test",
     size = "medium",
     srcs = ["pooling_ops_3d_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     shard_count = 10,
     tags = [
diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py
index 9a008940fa2..a833daa39be 100644
--- a/tensorflow/compiler/tests/pooling_ops_3d_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import nn_ops
@@ -74,6 +75,9 @@ class Pooling3DTest(xla_test.XLATestCase):
     actual = vals.flatten()
     self.assertAllClose(expected, actual)
 
+  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
+                                 " doesn't support all paddings and data "
+                                 "formats")
   def testAvgPool3dValidPadding(self):
     expected_output = [20.5, 21.5, 22.5]
     self._VerifyValues(
@@ -84,6 +88,9 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="VALID",
         expected=expected_output)
 
+  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
+                                 " doesn't support all paddings and data "
+                                 "formats")
   def testAvgPool3dSamePadding(self):
     expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5]
     self._VerifyValues(
@@ -94,6 +101,9 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="SAME",
         expected=expected_output)
 
+  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
+                                 " doesn't support all paddings and data "
+                                 "formats")
   def testAvgPool3dSamePaddingDifferentStrides(self):
     expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5]
     self._VerifyValues(
@@ -154,6 +164,8 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="SAME",
         expected=expected_output.flatten())
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testKernelSmallerThanStride(self):
     self._VerifyValues(
         nn_ops.max_pool3d,
@@ -311,6 +323,8 @@ class Pooling3DTest(xla_test.XLATestCase):
         atol=1e-6)
     self.assertShapeEqual(actual_grad_gradients_vals, outputs)
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPoolGradValidPadding1_1_3d(self):
     self._VerifyGradient(
         nn_ops.max_pool3d,
@@ -321,6 +335,8 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="VALID",
         pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
  def testMaxPoolGradValidPadding2_1_6_3d(self):
    self._VerifyGradient(
        nn_ops.max_pool3d,
@@ -343,6 +359,8 @@ class Pooling3DTest(xla_test.XLATestCase):
        strides=[1, 1, 1],
        padding="VALID")
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
  def testMaxPoolGradValidPadding2_2_3d(self):
    self._VerifyGradient(
        nn_ops.max_pool3d,
@@ -353,6 +371,8 @@ class Pooling3DTest(xla_test.XLATestCase):
        padding="VALID",
        pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
  def testMaxPoolGradSamePadding1_1_3d(self):
    self._VerifyGradient(
        nn_ops.max_pool3d,
@@ -363,6 +383,8 @@ class Pooling3DTest(xla_test.XLATestCase):
        padding="SAME",
        pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
  def testMaxPoolGradSamePadding2_1_3d(self):
    self._VerifyGradient(
        nn_ops.max_pool3d,
@@ -373,6 +395,8 @@ class Pooling3DTest(xla_test.XLATestCase):
        padding="SAME",
        pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
  def testMaxPoolGradSamePadding2_2_3d(self):
    self._VerifyGradient(
        nn_ops.max_pool3d,
@@ -383,6 +407,8 @@ class Pooling3DTest(xla_test.XLATestCase):
        padding="SAME",
        pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
  def testMaxPoolGradSamePadding3_1_3d(self):
    self._VerifyGradient(
        nn_ops.max_pool3d,
diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py
index bcc5ce77ec6..293e1010b08 100644
--- a/tensorflow/compiler/tests/pooling_ops_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_test.py
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import nn_ops
@@ -267,6 +268,9 @@ class PoolingTest(xla_test.XLATestCase):
        expected=[1, 3, 9, 11])
 
   # Average pooling
+  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
+                                 " doesn't support all paddings and data "
+                                 "formats")
  def testAvgPoolValidPadding(self):
    expected_output = [7, 8, 9]
    self._VerifyValues(
@@ -277,6 +281,9 @@ class PoolingTest(xla_test.XLATestCase):
        padding="VALID",
        expected=expected_output)
 
+  @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
+                                 " doesn't support all paddings and data "
+                                 "formats")
  def testAvgPoolSamePadding(self):
    expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
    self._VerifyValues(
@@ -542,12 +549,20 @@ class PoolGradTest(xla_test.XLATestCase):
        padding="SAME",
        pool_grad_grad_func=pool_grad_grad_func)
 
+  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
+                                 "MaxPoolGradGrad op in MLIR-based bridge")
  def testMaxPool(self):
    self._TestPooling(
        nn_ops.max_pool,
        gen_nn_ops.max_pool_grad,
        pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad)
 
+  # TODO(b/159845178): Remove this once MLIR bridge supports MaxPoolGradGrad
+  # (then `testMaxPool` test will be sufficient)
+  def testMaxPoolNoGradGrad(self):
+    self._TestPooling(
+        nn_ops.max_pool, gen_nn_ops.max_pool_grad, pool_grad_grad_func=None)
+
  def testAvgPool(self):
    # Wrapper around AvgPoolGrad that ignores extra arguments needed by
    # MaxPoolGrad.
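
Editor's note, illustrative usage (a sketch, not part of the patch): the computation the new TF_AvgPool3DGradOp definition and ConvertAvgPool3DGradOp pattern cover, exercised through the public tf.raw_ops.AvgPool3DGrad entry point. Shapes and attributes mirror the @avgpool_3d_grad_valid_padding test above; in the test suite the op runs on an XLA device via xla_test.XLATestCase, whereas this sketch only checks the expected numerics eagerly.

import numpy as np
import tensorflow as tf

# Gradient of a 1x2x2 (D x H x W) average pool with VALID padding and NDHWC
# layout, matching the @avgpool_3d_grad_valid_padding test case above.
out_grad = tf.ones([10, 8, 12, 16, 64], dtype=tf.float32)
result = tf.raw_ops.AvgPool3DGrad(
    orig_input_shape=[10, 8, 24, 32, 64],
    grad=out_grad,
    ksize=[1, 1, 2, 2, 1],
    strides=[1, 1, 2, 2, 1],
    padding="VALID",
    data_format="NDHWC")

# With VALID padding and non-overlapping windows, every input cell belongs to
# exactly one 1x2x2 window, so each gradient entry is 1.0 / 4 = 0.25.
np.testing.assert_allclose(result.numpy(),
                           np.full([10, 8, 24, 32, 64], 0.25, np.float32))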