Add new_axis_mask support to tf.StridedSlice lowering.

To add new_axis_mask support we only need to keep track for the number of new axis after ellipsis. This helps in counting the number of dimensions ellipsis represents and hence produce the correct attributes of  `xla_hlo.slice` op. The result of slice is reshaped to reflect new axes added.

PiperOrigin-RevId: 294342718
Change-Id: Ic2fc3d9116fd2160a5341826180b1b0bdc231698
This commit is contained in:
Prakalp Srivastava 2020-02-10 17:32:46 -08:00 committed by TensorFlower Gardener
parent 12b968f470
commit f78f9898d9
2 changed files with 47 additions and 8 deletions

View File

@ -2356,15 +2356,23 @@ static LogicalResult BuildDenseSliceSpec(const SparseSliceSpec &sparse,
dense->end_mask = 0;
dense->shrink_axis_mask = 0;
// Count number of new_axis after ellipsis. This helps in calculating the
// number of dimensions ellipsis represents in the sparse spec.
bool ellipsis_seen = false;
int num_new_axis_after_ellipsis = 0;
for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
num_new_axis_after_ellipsis++;
if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
}
int dense_index = 0;
for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
if (IsSet(sparse.new_axis_mask, sparse_index)) {
// TODO(b/146512589): Add support for new_axis_mask.
continue;
}
if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
if (IsSet(sparse.ellipsis_mask, sparse_index)) {
auto next_index =
std::min(dense->dims - (sparse.dims - sparse_index) + 1, dense->dims);
auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
1 + num_new_axis_after_ellipsis,
dense->dims);
// Expand ellipsis into the appropriate dense indices. From current index
// until next_index, all dimensions would have begin and end masks set and
// stride 1, i.e., get all elements in those dimensions.
@ -2461,8 +2469,6 @@ static void CalculateSlicedShapeAndBoundRanges(
bool StridedSliceOp::GetSlicedBoundRanges(
SmallVectorImpl<int64_t> *begin_indices,
SmallVectorImpl<int64_t> *end_indices, SmallVectorImpl<int64_t> *strides) {
if (this->new_axis_mask().getZExtValue())
return false; // TODO(b/146512589): support these masks
// TODO(hinsu): Support lowering for ops with dynamic begin and end values
// when it is possible to derive indices based on mask attributes.
DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;

View File

@ -2375,6 +2375,39 @@ func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
return
}
// CHECK-LABEL: strided_slice_new_axis_mask
// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
// For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
// New axis mask is at index 1 and 6 of sparse spec, so
// new_axis_mask = 2^1 + 2^6 = 66
// The ellipsis mask is applied to dim #1, #2 of input i.e, we get
// canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
// This is then reshaped to add the new axes.
// The start, limit indices and strides attributes of xla_hlo.slice would
// reflect the canonicalized slice.
// As output shape of StridedSlice differs, a reshape will follow to reflect
// new axes added.
%begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
%end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
%strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)
// CHECK: %[[SLICE:.*]] = "xla_hlo.slice"(%[[INPUT]])
// CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
// CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
// CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64>
// CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32>
// CHECK: "xla_hlo.reshape"(%[[SLICE]])
// CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32>
return
}
//===----------------------------------------------------------------------===//
// Reduction op legalizations.
//===----------------------------------------------------------------------===//