diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index fa1394884bf..513567116bc 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1227,7 +1227,7 @@ func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
 // CHECK-LABEL: maxpool_same_padding
 // CHECK-SAME: %[[ARG:.*]]: tensor
 func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> {
-  // CHECK: padding = dense<{{\[\[}}0, 0, 1, 0], [0, 1, 1, 0]]> : tensor<2x4xi64>
+  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
   %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32>
   return %0 : tensor<2x4x7x7xi32>
 }
@@ -1263,7 +1263,7 @@ func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: te
 // CHECK-LABEL: @max_pool_grad_same
 func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>,
                          %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
-  // CHECK: padding = dense<{{\[\[}}0, 0, 1, 0], [0, 1, 1, 0]]> : tensor<2x4xi64>
+  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
   %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
      data_format = "NHWC",
      ksize = [1, 2, 3, 1],
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 30be7fe9fc8..e14f6a20d79 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -997,11 +997,11 @@ static DenseIntElementsAttr GetReduceWindowPadding(
   int64_t rank = paddings.size();
   llvm::SmallVector<int64_t, 8> flatten_paddings(rank * 2);
   for (int i = 0; i < rank; i++) {
-    flatten_paddings[i] = paddings[i].first;
-    flatten_paddings[rank + i] = paddings[i].second;
+    flatten_paddings[2 * i] = paddings[i].first;
+    flatten_paddings[2 * i + 1] = paddings[i].second;
   }
   return DenseIntElementsAttr::get(
-      RankedTensorType::get({2, rank}, builder->getIntegerType(64)),
+      RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
       flatten_paddings);
 }