Added TF-to-XLA_HLO lowering for AvgPoolGrad

Also added required ops and MLIR unit tests.

PiperOrigin-RevId: 318199300
Change-Id: I46bf921b5a14c1c4428bfdf51d3e3415a3af65bc
This commit is contained in:
Michael Gester 2020-06-24 20:59:56 -07:00 committed by TensorFlower Gardener
parent 896c922092
commit 772b836fdb
7 changed files with 580 additions and 20 deletions

View File

@ -745,6 +745,29 @@ def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
// Gradient op for 3D average pooling. Note that operand 0 is only the *shape*
// of the original input, not the input tensor itself: the gradient of average
// pooling is determined entirely by shapes, window size, strides, and padding
// (the Jacobian is constant -- see the ConvertAvgPoolGradOp lowering comment).
def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
  let summary = "Computes gradients of the 3D average pooling function.";

  let description = [{
  }];

  let arguments = (ins
    // Shape of the original (forward-pass) input, as a 1-D int32 tensor.
    I32Tensor:$orig_input_shape,
    // Back-propagated gradient w.r.t. the output of AvgPool3D.
    TF_FpTensor:$grad,

    // Window size, strides: one entry per dimension (rank 5: batch, 3 spatial
    // dims, channel -- order given by $data_format).
    Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
    Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
    TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
    DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
  );

  let results = (outs
    TF_FpTensor:$output
  );

  // Element type is derived from operand 1 ($grad), since operand 0 is always
  // int32.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> {
let summary = "Multiplies slices of two tensors in batches.";

View File

@ -175,6 +175,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/compiler/xla/client/lib:conv_grad_size_util",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"//tensorflow/core/lib/bfloat16",

View File

@ -1096,7 +1096,6 @@ func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: te
// CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
// CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
// CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32>
// CHECK: }
%result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
data_format = "NHWC",
ksize = [1, 2, 2, 1],
@ -1120,7 +1119,6 @@ func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_outpu
// CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
// CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
// CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
// CHECK: }
%result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
return %result : tensor<10x8x24x24x64xf32>
}
@ -3838,6 +3836,260 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
return %0 : tensor<2x4x7x7xf32>
}
//===----------------------------------------------------------------------===//
// AvgPoolGrad op legalizations.
//===----------------------------------------------------------------------===//
// Verifies VALID-padding AvgPoolGrad lowering: the output gradient is divided
// by the constant window count (2x2 = 4), interior/edge padded, then summed
// with a stride-1 reduce_window.
// Note: two directives below previously read "CHECK_SAME" (underscore), which
// FileCheck silently ignores; fixed to "CHECK-SAME" so they are enforced.
// CHECK-LABEL: func @avgpool_grad_valid_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME: broadcast_dimensions = dense<[]>
// CHECK-SAME: -> tensor<10x12x16x64xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME: -> tensor<10x25x33x64xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<10x24x32x64xf32>
// CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32>
func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32>
  return %result : tensor<10x24x32x64xf32>
}
// Verifies VALID-padding AvgPool3DGrad lowering (NDHWC): same pattern as the
// 2D case, with the depth dimension carried through untouched (window 1,
// stride 1, zero padding on dims 0 and 1).
// CHECK-LABEL: func @avgpool_3d_grad_valid_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: -> tensor<10x8x25x33x64xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<10x8x24x32x64xf32>
// CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32>
func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
%orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>)
%result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
data_format = "NDHWC",
ksize = [1, 1, 2, 2, 1],
padding = "VALID",
strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32>
return %result : tensor<10x8x24x32x64xf32>
}
// Verifies SAME-padding AvgPoolGrad lowering: window counts are computed by
// reduce-windowing an all-ones tensor of the original input shape (padded
// entries are not counted toward the average), then the divided gradient is
// padded and summed with a stride-1 reduce_window.
// CHECK-LABEL: func @avgpool_grad_same_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME: -> tensor<2x4x7x9xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]>
// CHECK-SAME: -> tensor<2x14x27x9xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<2x13x25x9xf32>
// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32>
func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
%orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>)
%result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
data_format = "NHWC",
ksize = [1, 2, 3, 1],
padding = "SAME",
strides = [1, 4, 4, 1]
} : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32>
return %result : tensor<2x13x25x9xf32>
}
// Verifies SAME-padding AvgPool3DGrad lowering (NDHWC): identical structure to
// the 2D SAME-padding case with an extra pass-through depth dimension.
// CHECK-LABEL: func @avgpool_3d_grad_same_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME: -> tensor<2x8x4x7x9xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]>
// CHECK-SAME: -> tensor<2x8x14x27x9xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<2x8x13x25x9xf32>
// CHECK: return %[[RESULT]] : tensor<2x8x13x25x9xf32>
func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
%orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>)
%result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
data_format = "NDHWC",
ksize = [1, 1, 2, 3, 1],
padding = "SAME",
strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32>
return %result : tensor<2x8x13x25x9xf32>
}
// Verifies SAME-padding AvgPoolGrad lowering with NCHW data format: the
// spatial padding/window attributes land on dims 2 and 3 instead of 1 and 2.
// CHECK-LABEL: func @avgpool_grad_nchw_format(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x9x4x7xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]>
// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]>
// CHECK-SAME: -> tensor<2x9x14x27xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<2x9x13x25xf32>
// CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32>
func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
%orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>)
%result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
data_format = "NCHW",
ksize = [1, 1, 2, 3],
padding = "SAME",
strides = [1, 1, 4, 4]
} : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32>
return %result : tensor<2x9x13x25xf32>
}
// Verifies SAME-padding AvgPool3DGrad lowering with NCDHW data format.
// Note: function renamed from "ncdwh" to "ncdhw" to match the NCDHW
// data_format it actually tests (CHECK-LABEL and func renamed together).
// CHECK-LABEL: func @avgpool_3d_grad_ncdhw_format(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
// CHECK: %[[DIVISOR:.*]] = "xla_hlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x9x8x4x7xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_hlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]>
// CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]>
// CHECK-SAME: -> tensor<2x9x8x14x27xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = xla_hlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<1> : tensor<5xi64>
// CHECK-SAME: -> tensor<2x9x8x13x25xf32>
// CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32>
func @avgpool_3d_grad_ncdhw_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
%orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>)
%result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
data_format = "NCDHW",
ksize = [1, 1, 1, 2, 3],
padding = "SAME",
strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32>
return %result : tensor<2x9x8x13x25xf32>
}
// Verifies AvgPoolGrad lowering for bf16 input: the reduce_window input is
// converted to f32 for summation (to avoid precision loss) and the result is
// converted back to bf16.
// CHECK-LABEL: func @avgpool_grad_bf16(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<4.000000e+00> : tensor<bf16>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME: broadcast_dimensions = dense<[]>
// CHECK-SAME: -> tensor<10x12x16x64xbf16>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "xla_hlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME: -> tensor<10x25x33x64xbf16>
// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = "xla_hlo.convert"(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32>
// CHECK: %[[ZERO_F32:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM:.*]] = xla_hlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "xla_hlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<10x24x32x64xf32>
// CHECK: %[[RESULT_CONVERTED:.*]] = "xla_hlo.convert"(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16>
// CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16>
func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
%orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
%result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
data_format = "NHWC",
ksize = [1, 2, 2, 1],
padding = "VALID",
strides = [1, 2, 2, 1]
} : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16>
return %result : tensor<10x24x32x64xbf16>
}
// CHECK-LABEL: xla_sharding
func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
// CHECK-NEXT: "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, xla_hlo.sharding = ""}

View File

@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -570,7 +571,7 @@ Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
dimension_numbers, precision_config);
}
// Builds body for reduce op by using the using the template binary op as the
// Builds body for reduce op by using the template binary op as the
// reducer op.
template <typename Op>
static void BuildReduceBody(Type element_type, Region *body,
@ -1697,17 +1698,21 @@ class ConvertFusedBatchNormV3Op
}
};
// Returns padding attribute for ReduceWindow op with given params.
using PaddingArray =
std::vector<std::pair<tensorflow::int64, tensorflow::int64>>;
// Returns padding values for ReduceWindow op as a vector of pairs.
//
// Requires padding to be either 'SAME' or 'VALID' and the number of input
// dimensions to be equal to the size of window dimensions and window strides.
template <int num_dims>
static DenseIntElementsAttr GetReduceWindowPadding(
static PaddingArray GetReduceWindowPaddingAsArray(
llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
ArrayAttr window_strides, StringRef padding, Builder *builder) {
if (padding == "VALID") return {};
DCHECK_EQ(padding.str(), "SAME");
if (padding == "VALID") {
return PaddingArray(num_dims, std::make_pair(0, 0));
}
assert(padding == "SAME");
llvm::SmallVector<tensorflow::int64, num_dims> input_shape, window_shape,
strides;
input_shape.reserve(input_dims.size());
@ -1720,9 +1725,21 @@ static DenseIntElementsAttr GetReduceWindowPadding(
for (Attribute attr : window_strides)
strides.push_back(attr.cast<IntegerAttr>().getInt());
std::vector<std::pair<tensorflow::int64, tensorflow::int64>> paddings =
::xla::MakePadding(input_shape, window_shape, strides,
::xla::Padding::kSame);
PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides,
::xla::Padding::kSame);
return paddings;
}
// Same as GetReduceWindowPaddingAsArray but returns padding as
// DenseIntElementsAttr. Returns empty attribute for `VALID` padding.
template <int num_dims>
static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
ArrayAttr window_strides, StringRef padding, Builder *builder) {
if (padding == "VALID") return {};
assert(padding == "SAME");
PaddingArray paddings = GetReduceWindowPaddingAsArray<num_dims>(
input_dims, window_dims, window_strides, padding, builder);
int64_t rank = paddings.size();
llvm::SmallVector<int64_t, num_dims * 2> flatten_paddings(rank * 2);
for (int i = 0; i < rank; i++) {
@ -1774,8 +1791,8 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
Value init =
GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
DenseIntElementsAttr paddings_attr =
GetReduceWindowPadding<4>(input_type.getShape(), op.ksize(),
op.strides(), op.padding(), &rewriter);
GetReduceWindowPaddingAsAttr<4>(input_type.getShape(), op.ksize(),
op.strides(), op.padding(), &rewriter);
auto reduce = rewriter.create<ReduceWindowOp>(
op.getLoc(), result_type, input_value, init,
GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
@ -1807,6 +1824,229 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
}
};
// `AvgPoolGradOp` is converted to the following operations:
// 1. Divide each entry of the output gradient (the gradient for the previous
// layer in backpropagation order) by the count of the corresponding window
// (i.e., the number of non-padding entries of the window which `AvgPool`
// has mapped to this entry in forward propagation).
// 2. Add appropriate interior and exterior padding for step 3 (see example
// below).
// 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape
// as windows) and stride 1 in each dimension. This is implemented as a
// `ReduceWindowOp` with `AddOp` as body.
//
// Example:
// Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2,
// and SAME padding with 0's. It is defined by
// f(x) = [ (x_1 + x_2 + x_3) / 3 ] ( x = (x_1, x_2, x_3, x_4) )
// [ (x_3 + x_4 + 0) / 2 ] (the 0 results from right padding)
// Note that for SAME padding in `AvgPool` the padded entries are not counted
// for the average, this is why the second denominator is 2 and not 3.
// The Jacobian Df is
// [ 1/3 1/3 1/3 0 ]
// [ 0 0 1/2 1/2 ]
//
// Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only
// needs the original input shape and not the tensor as argument).
// Let v = [ 4 6 ]^T be the output gradient (^T = transposed). Then the
// average pool gradient is given by
// Df^T * v = [ 4/3 4/3 13/3 3 ]^T
// Instead of a matrix-vector-multiplication we can utilize the sparsity and
// structure of Df by using the 3-step approach from above:
// 1. Divide output gradient v by window counts: [ 4/3 6/2 ]^T
// 2. Add appropriate padding: [ 0 0 4/3 0 3 0 ]^T
// 3. Convolve with kernel [ 1 1 1 ]: [ 4/3 4/3 11/3 3 ]^T
//
// Note that the padding in step 2. is chosen in such a way that the subsequent
// convolution produces the gradient. Higher dimensions, different padding, and
// different windows/strides work in a similar way, the main difference is in
// the computation of the paddings in step 2.
//
// For more details on backpropagation for convolution of which `AvgPoolGrad`
// is a special case see `tensorflow/core/kernels/conv_grad_ops.h`.
// `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more
// examples for different cases.
// Pattern that lowers `tf.AvgPoolGrad` (num_dims == 4) and
// `tf.AvgPool3DGrad` (num_dims == 5) to XLA HLO ops, following the three-step
// algorithm described in the comment block above.
template <typename OpTy, int num_dims>
class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
  using DimVector = SmallVector<int64_t, num_dims>;

 public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    tensorflow::TensorFormat data_format;
    if (!FormatFromString(op.data_format().str(), &data_format)) {
      return failure();
    }
    // `out_grad` is the gradient that was propagated via backpropagation from
    // the output layer.
    Value out_grad = op.grad();
    auto out_grad_type =
        out_grad.getType().template dyn_cast<RankedTensorType>();
    if (!out_grad_type) {
      return failure();
    }
    Type element_type = out_grad_type.getElementType();
    // Only the *shape* of the original input is needed (the Jacobian of
    // average pooling is constant); it must be statically known, i.e. a
    // constant op, for this pattern to apply.
    DenseIntElementsAttr orig_input_shape_attr;
    if (!matchPattern(op.orig_input_shape(),
                      m_Constant(&orig_input_shape_attr))) {
      return failure();
    }
    auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
    DimVector orig_input_shape(orig_input_shape_values.begin(),
                               orig_input_shape_values.end());
    RankedTensorType orig_input_type =
        RankedTensorType::get(orig_input_shape, element_type);
    DimVector ksize, strides;
    GetI64ArrayAttrValues(op.ksize(), &ksize);
    GetI64ArrayAttrValues(op.strides(), &strides);
    Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);

    // Step 1: divide `out_grad` by the per-window entry counts.
    Operation *out_grad_divided = nullptr;
    if (op.padding() == "VALID") {
      // All window counts are equal here because we don't have padding
      // (each entry of `out_grad` corresponds to a window that consists of
      // original input entries only).
      int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
                                             std::multiplies<int64_t>());
      // Divide `out_grad` by window counts.
      Value divisor =
          GetScalarConstOfType(element_type, loc, window_count, &rewriter);
      auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
      out_grad_divided = rewriter.create<xla_chlo::BroadcastDivOp>(
          loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims);
    } else {
      assert(op.padding() == "SAME");
      // For SAME padding, only original entries that contributed to a window
      // are counted for the average of this window, not padded entries.

      // Build all-ones tensor of same shape as the original input.
      ElementsAttr splat = xla::getSplat(&rewriter, orig_input_type, 1);
      auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);

      // Get the same padding as for the original input.
      DenseIntElementsAttr orig_padding_attr =
          GetReduceWindowPaddingAsAttr<num_dims>(orig_input_shape, op.ksize(),
                                                 op.strides(), op.padding(),
                                                 &rewriter);

      // Count the 1's in each window, using the same padding as for the
      // original input, which gives us the window counts by which `out_grad`
      // needs to be divided.
      auto window_counts = rewriter.create<ReduceWindowOp>(
          loc, out_grad_type,
          /*operand=*/all_ones_tensor,
          /*init_value=*/zero,
          /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
          /*window_strides=*/GetI64ElementsAttr(op.strides()),
          /*base_dilations=*/DenseIntElementsAttr(),
          /*window_dilations=*/DenseIntElementsAttr(),
          /*padding=*/orig_padding_attr);
      BuildReduceBody<AddOp>(element_type, &window_counts.body(), &rewriter);

      // Divide `out_grad` by window counts.
      out_grad_divided = rewriter.create<xla_hlo::DivOp>(
          loc, out_grad_type, out_grad, window_counts);
    }

    // Step 2: get same padding as for original input.
    PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
        orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);

    // Add padding around `out_grad_divided` values in such a way that the
    // subsequent `ReduceWindowOp` produces the gradient.
    DimVector out_grad_shape(
        llvm::to_vector<num_dims>(out_grad_type.getShape()));
    DimVector low_padding(num_dims, 0);
    DimVector high_padding(num_dims, 0);
    DimVector interior_padding(num_dims, 0);
    // Batch and channel dimensions get no padding; only the spatial
    // dimensions (selected via `data_format`) are adjusted below.
    constexpr int num_spatial_dims = num_dims - 2;
    for (int i = 0; i < num_spatial_dims; ++i) {
      int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
      int orig_input_shape_padded_in_dim = orig_input_shape[dim] +
                                           orig_padding[dim].first +
                                           orig_padding[dim].second;
      // Set interior padding such that neighboring entries from
      // `out_grad_divided` have distance `strides[dim]` from each other in
      // every dimension.
      interior_padding[dim] = strides[dim] - 1;
      // Set exterior padding in the same way as for convolution gradient
      // computation.
      auto status = ::xla::ConvGradExtractAndVerifyDimension(
          /*input_size=*/orig_input_shape_padded_in_dim,
          /*filter_size=*/ksize[dim],
          /*output_size=*/out_grad_shape[dim],
          /*dilation=*/1,
          /*stride=*/strides[dim],
          /*padding=*/::xla::Padding::kValid);
      if (!status.ok()) {
        return failure();
      }
      ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim =
          status.ValueOrDie();
      // Subtract the original exterior padding since it doesn't contribute to
      // the gradient. Note that we save one `PadOp` and some unnecessary kernel
      // computations, compared to the `xla::AvgPoolGrad` implementation, by
      // subtracting the original exterior padding before `ReduceWindowOp`
      // instead of trimming the result of `ReduceWindowOp` (the final result is
      // the same because all strides are 1).
      low_padding[dim] =
          conv_grad_spatial_dim.pad_before - orig_padding[dim].first;
      high_padding[dim] =
          conv_grad_spatial_dim.pad_after - orig_padding[dim].second;

      // Update `out_grad_shape` to result shape of following `PadOp`.
      out_grad_shape[dim] = low_padding[dim] + high_padding[dim] +
                            (out_grad_shape[dim] - 1) * strides[dim] + 1;
    }
    Value reduce_window_input = rewriter.create<PadOp>(
        loc, RankedTensorType::get(out_grad_shape, element_type),
        /*operand=*/out_grad_divided->getOpResult(0),
        /*padding_value=*/zero,
        /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter),
        /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter),
        /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter));

    // Step 3: compute result by convolving `reduce_window_input` with an
    // all-ones kernel, using `ReduceWindowOp` with `AddOp` body.
    Type sum_element_type = GetSumAccumulationType(element_type);
    if (element_type != sum_element_type) {
      // Convert to appropriate sum accumulation type to avoid precision loss.
      reduce_window_input = rewriter.create<ConvertOp>(loc, reduce_window_input,
                                                       sum_element_type);
      zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter);
    }
    auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter);
    auto reduce_window_op = rewriter.create<ReduceWindowOp>(
        loc, RankedTensorType::get(orig_input_shape, sum_element_type),
        /*operand=*/reduce_window_input,
        /*init_value=*/zero,
        /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
        /*window_strides=*/ones,
        /*base_dilations=*/DenseIntElementsAttr(),
        /*window_dilations=*/DenseIntElementsAttr(),
        /*padding=*/DenseIntElementsAttr());
    BuildReduceBody<AddOp>(sum_element_type, &reduce_window_op.body(),
                           &rewriter);
    Value result = reduce_window_op.getResult();

    if (element_type != sum_element_type) {
      // Convert back to original element type.
      result = rewriter.create<ConvertOp>(op.getLoc(), result, element_type);
    }
    rewriter.replaceOp(op, {result});
    return success();
  }
};

using ConvertAvgPool2DGradOp =
    ConvertAvgPoolGradOp<TF::AvgPoolGradOp, /*num_dims=*/4>;
using ConvertAvgPool3DGradOp =
    ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;
// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
// dimensions with max as the reduction function.
//
@ -1831,7 +2071,7 @@ class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
if (!input_ty) return failure();
DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<num_dims>(
DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
auto reduce = rewriter.create<ReduceWindowOp>(
loc, op.getType(), op.input(), init.getResult(),
@ -3381,7 +3621,7 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
auto input_ty =
op.orig_input().getType().template dyn_cast<RankedTensorType>();
if (!input_ty) return failure();
DenseIntElementsAttr paddings_attr = GetReduceWindowPadding<num_dims>(
DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
auto result = rewriter.create<SelectAndScatterOp>(
@ -5202,11 +5442,12 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp, ConvertRangeOp,
ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertAvgPool2DGradOp,
ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,

View File

@ -934,6 +934,7 @@ tf_xla_py_test(
name = "pooling_ops_test",
size = "medium",
srcs = ["pooling_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 10,
tags = [
@ -953,6 +954,7 @@ tf_xla_py_test(
name = "pooling_ops_3d_test",
size = "medium",
srcs = ["pooling_ops_3d_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 10,
tags = [

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
@ -74,6 +75,9 @@ class Pooling3DTest(xla_test.XLATestCase):
actual = vals.flatten()
self.assertAllClose(expected, actual)
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
" doesn't support all paddings and data "
"formats")
def testAvgPool3dValidPadding(self):
expected_output = [20.5, 21.5, 22.5]
self._VerifyValues(
@ -84,6 +88,9 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="VALID",
expected=expected_output)
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
" doesn't support all paddings and data "
"formats")
def testAvgPool3dSamePadding(self):
expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5]
self._VerifyValues(
@ -94,6 +101,9 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="SAME",
expected=expected_output)
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
" doesn't support all paddings and data "
"formats")
def testAvgPool3dSamePaddingDifferentStrides(self):
expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5]
self._VerifyValues(
@ -154,6 +164,8 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="SAME",
expected=expected_output.flatten())
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testKernelSmallerThanStride(self):
self._VerifyValues(
nn_ops.max_pool3d,
@ -311,6 +323,8 @@ class Pooling3DTest(xla_test.XLATestCase):
atol=1e-6)
self.assertShapeEqual(actual_grad_gradients_vals, outputs)
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testMaxPoolGradValidPadding1_1_3d(self):
self._VerifyGradient(
nn_ops.max_pool3d,
@ -321,6 +335,8 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="VALID",
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testMaxPoolGradValidPadding2_1_6_3d(self):
self._VerifyGradient(
nn_ops.max_pool3d,
@ -343,6 +359,8 @@ class Pooling3DTest(xla_test.XLATestCase):
strides=[1, 1, 1],
padding="VALID")
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testMaxPoolGradValidPadding2_2_3d(self):
self._VerifyGradient(
nn_ops.max_pool3d,
@ -353,6 +371,8 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="VALID",
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testMaxPoolGradSamePadding1_1_3d(self):
self._VerifyGradient(
nn_ops.max_pool3d,
@ -363,6 +383,8 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="SAME",
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testMaxPoolGradSamePadding2_1_3d(self):
self._VerifyGradient(
nn_ops.max_pool3d,
@ -373,6 +395,8 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="SAME",
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testMaxPoolGradSamePadding2_2_3d(self):
self._VerifyGradient(
nn_ops.max_pool3d,
@ -383,6 +407,8 @@ class Pooling3DTest(xla_test.XLATestCase):
padding="SAME",
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testMaxPoolGradSamePadding3_1_3d(self):
self._VerifyGradient(
nn_ops.max_pool3d,

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
@ -267,6 +268,9 @@ class PoolingTest(xla_test.XLATestCase):
expected=[1, 3, 9, 11])
# Average pooling
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
" doesn't support all paddings and data "
"formats")
def testAvgPoolValidPadding(self):
expected_output = [7, 8, 9]
self._VerifyValues(
@ -277,6 +281,9 @@ class PoolingTest(xla_test.XLATestCase):
padding="VALID",
expected=expected_output)
@test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering"
" doesn't support all paddings and data "
"formats")
def testAvgPoolSamePadding(self):
expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
self._VerifyValues(
@ -542,12 +549,20 @@ class PoolGradTest(xla_test.XLATestCase):
padding="SAME",
pool_grad_grad_func=pool_grad_grad_func)
# Exercises max_pool together with its first-order gradient
# (max_pool_grad) and second-order gradient (max_pool_grad_grad) via
# _TestPooling. Disabled under the MLIR bridge because MaxPoolGradGrad is
# not yet lowered there (b/159845178); testMaxPoolNoGradGrad covers the
# bridge until then.
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
"MaxPoolGradGrad op in MLIR-based bridge")
def testMaxPool(self):
self._TestPooling(
nn_ops.max_pool,
gen_nn_ops.max_pool_grad,
pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad)
# TODO(b/159845178): Remove this once MLIR bridge supports MaxPoolGradGrad
# (then `testMaxPool` test will be sufficient)
def testMaxPoolNoGradGrad(self):
# Same as testMaxPool but passes pool_grad_grad_func=None, skipping the
# second-order gradient check so the test can run on the MLIR bridge.
self._TestPooling(
nn_ops.max_pool, gen_nn_ops.max_pool_grad, pool_grad_grad_func=None)
def testAvgPool(self):
# Wrapper around AvgPoolGrad that ignores extra arguments needed by
# MaxPoolGrad.