Handle implicit ellipsis_mask in tf.StridedSliceOp.
For tf.StridedSliceOp, if no ellipsis_mask exists then an implicit ellipsis_mask at the end is assumed. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...]. PiperOrigin-RevId: 299152429 Change-Id: I6726f5e233c6b251ab9c5b716c8ebf10f490d318
This commit is contained in:
parent
6506145ebb
commit
d2a436fed9
@ -2621,9 +2621,8 @@ constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
|
|||||||
// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2,
|
// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2,
|
||||||
// 4, 8) would have dims = 2.
|
// 4, 8) would have dims = 2.
|
||||||
struct SparseSliceSpec {
|
struct SparseSliceSpec {
|
||||||
const int64_t dims;
|
int64_t dims;
|
||||||
const int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask,
|
int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
|
||||||
shrink_axis_mask;
|
|
||||||
const ArrayRef<int64_t> &begin;
|
const ArrayRef<int64_t> &begin;
|
||||||
const ArrayRef<int64_t> &end;
|
const ArrayRef<int64_t> &end;
|
||||||
const ArrayRef<int64_t> &strides;
|
const ArrayRef<int64_t> &strides;
|
||||||
@ -2783,6 +2782,14 @@ static void CalculateSlicedShapeFromSparseIndices(
|
|||||||
ellipsis_mask, new_axis_mask, shrink_axis_mask,
|
ellipsis_mask, new_axis_mask, shrink_axis_mask,
|
||||||
sparse_begin, sparse_end, sparse_strides};
|
sparse_begin, sparse_end, sparse_strides};
|
||||||
|
|
||||||
|
// If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
|
||||||
|
// inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
|
||||||
|
// a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].
|
||||||
|
if (sparse.ellipsis_mask == 0) {
|
||||||
|
Set(sparse.ellipsis_mask, sparse.dims);
|
||||||
|
sparse.dims++;
|
||||||
|
}
|
||||||
|
|
||||||
int64_t dims = input_shape.size();
|
int64_t dims = input_shape.size();
|
||||||
DenseSliceSpec dense = {dims,
|
DenseSliceSpec dense = {dims,
|
||||||
/*begin_mask = */ 0,
|
/*begin_mask = */ 0,
|
||||||
|
@ -2415,6 +2415,25 @@ func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: strided_slice_implicit_ellipsis_mask(
|
||||||
|
// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32>
|
||||||
|
func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> {
|
||||||
|
// StridedSlice gets input[8:10], which is same as input[8:10, ...]
|
||||||
|
// The start_indices, limit_indices, and strides attribute of xla_hlo.slice
|
||||||
|
// reflect the canonicalized slice.
|
||||||
|
%begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
%end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
// CHECK: [[SLICE:%.*]] = "xla_hlo.slice"([[INPUT]])
|
||||||
|
// CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64>
|
||||||
|
// CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64>
|
||||||
|
// CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64>
|
||||||
|
// CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32>
|
||||||
|
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32>
|
||||||
|
// CHECK: return [[RESHAPE]] : tensor<2x16x2xf32>
|
||||||
|
return %0 : tensor<2x16x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Reduction op legalizations.
|
// Reduction op legalizations.
|
||||||
|
Loading…
Reference in New Issue
Block a user