Handle implicit ellipsis_mask in tf.StridedSliceOp.
For tf.StridedSliceOp, if no ellipsis_mask exists then an implicit ellipsis_mask at the end is assumed. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...]. PiperOrigin-RevId: 299152429 Change-Id: I6726f5e233c6b251ab9c5b716c8ebf10f490d318
This commit is contained in:
parent
6506145ebb
commit
d2a436fed9
@ -2621,9 +2621,8 @@ constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
|
||||
// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2,
|
||||
// 4, 8) would have dims = 2.
|
||||
struct SparseSliceSpec {
|
||||
const int64_t dims;
|
||||
const int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask,
|
||||
shrink_axis_mask;
|
||||
int64_t dims;
|
||||
int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
|
||||
const ArrayRef<int64_t> &begin;
|
||||
const ArrayRef<int64_t> &end;
|
||||
const ArrayRef<int64_t> &strides;
|
||||
@ -2783,6 +2782,14 @@ static void CalculateSlicedShapeFromSparseIndices(
|
||||
ellipsis_mask, new_axis_mask, shrink_axis_mask,
|
||||
sparse_begin, sparse_end, sparse_strides};
|
||||
|
||||
// If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
|
||||
// inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
|
||||
// a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].
|
||||
if (sparse.ellipsis_mask == 0) {
|
||||
Set(sparse.ellipsis_mask, sparse.dims);
|
||||
sparse.dims++;
|
||||
}
|
||||
|
||||
int64_t dims = input_shape.size();
|
||||
DenseSliceSpec dense = {dims,
|
||||
/*begin_mask = */ 0,
|
||||
|
@ -2415,6 +2415,25 @@ func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: strided_slice_implicit_ellipsis_mask(
|
||||
// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32>
|
||||
func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> {
|
||||
// StridedSlice gets input[8:10], which is same as input[8:10, ...]
|
||||
// The start_indices, limit_indices, and strides attribute of xla_hlo.slice
|
||||
// reflect the canonicalized slice.
|
||||
%begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK: [[SLICE:%.*]] = "xla_hlo.slice"([[INPUT]])
|
||||
// CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64>
|
||||
// CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64>
|
||||
// CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64>
|
||||
// CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32>
|
||||
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32>
|
||||
// CHECK: return [[RESHAPE]] : tensor<2x16x2xf32>
|
||||
return %0 : tensor<2x16x2xf32>
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reduction op legalizations.
|
||||
|
Loading…
Reference in New Issue
Block a user