Handle implicit ellipsis_mask in tf.StridedSliceOp.

For tf.StridedSliceOp, if no ellipsis_mask exists then an implicit ellipsis_mask at the end is assumed. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].

PiperOrigin-RevId: 299152429
Change-Id: I6726f5e233c6b251ab9c5b716c8ebf10f490d318
This commit is contained in:
Prakalp Srivastava 2020-03-05 11:37:36 -08:00 committed by TensorFlower Gardener
parent 6506145ebb
commit d2a436fed9
2 changed files with 29 additions and 3 deletions

View File

@ -2621,9 +2621,8 @@ constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2,
// 4, 8) would have dims = 2.
struct SparseSliceSpec {
const int64_t dims;
const int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask;
int64_t dims;
int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
const ArrayRef<int64_t> &begin;
const ArrayRef<int64_t> &end;
const ArrayRef<int64_t> &strides;
@ -2783,6 +2782,14 @@ static void CalculateSlicedShapeFromSparseIndices(
ellipsis_mask, new_axis_mask, shrink_axis_mask,
sparse_begin, sparse_end, sparse_strides};
// If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
// inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
// a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].
if (sparse.ellipsis_mask == 0) {
Set(sparse.ellipsis_mask, sparse.dims);
sparse.dims++;
}
int64_t dims = input_shape.size();
DenseSliceSpec dense = {dims,
/*begin_mask = */ 0,

View File

@ -2415,6 +2415,25 @@ func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
return
}
// CHECK-LABEL: strided_slice_implicit_ellipsis_mask(
// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32>
func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> {
// StridedSlice gets input[8:10], which is same as input[8:10, ...]
// The start_indices, limit_indices, and strides attribute of xla_hlo.slice
// reflect the canonicalized slice.
%begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
%end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
%strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: [[SLICE:%.*]] = "xla_hlo.slice"([[INPUT]])
// CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64>
// CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64>
// CHECK: [[RESHAPE:%.*]] = "xla_hlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32>
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32>
// CHECK: return [[RESHAPE]] : tensor<2x16x2xf32>
return %0 : tensor<2x16x2xf32>
}
//===----------------------------------------------------------------------===//
// Reduction op legalizations.