diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 95440316294..d9c93d16bd4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -2621,9 +2621,8 @@ constexpr void CopyBit(const T &src, unsigned src_index, T &dst, // dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2, // 4, 8) would have dims = 2. struct SparseSliceSpec { - const int64_t dims; - const int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, - shrink_axis_mask; + int64_t dims; + int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask; const ArrayRef &begin; const ArrayRef &end; const ArrayRef &strides; @@ -2783,6 +2782,14 @@ static void CalculateSlicedShapeFromSparseIndices( ellipsis_mask, new_axis_mask, shrink_axis_mask, sparse_begin, sparse_end, sparse_strides}; + // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is + // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields + // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...]. + if (sparse.ellipsis_mask == 0) { + Set(sparse.ellipsis_mask, sparse.dims); + sparse.dims++; + } + int64_t dims = input_shape.size(); DenseSliceSpec dense = {dims, /*begin_mask = */ 0, diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 03277ff7c37..fa0be85d2fe 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -2415,6 +2415,25 @@ func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { return } +// CHECK-LABEL: strided_slice_implicit_ellipsis_mask( +// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32> +func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> { + // StridedSlice gets input[8:10], which is same as input[8:10, ...] + // The start_indices, limit_indices, and strides attribute of xla_hlo.slice + // reflect the canonicalized slice. + %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> + %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: [[SLICE:%.*]] = "xla_hlo.slice"([[INPUT]]) + // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64> + // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64> + // CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32> + // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32> + return %0 : tensor<2x16x2xf32> +} + //===----------------------------------------------------------------------===// // Reduction op legalizations.