diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index e6d4a84f8ba..6c635bd3500 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -453,6 +453,30 @@ func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: return %arg0 : tensor<3xf32> } +// CHECK-LABEL: @StridedSliceEllipsisMaskBefore +func @StridedSliceEllipsisMaskBefore(%arg0: tensor<21x15x7xf32>) -> tensor<21x15x2xf32> { + %cst = constant dense<0> : tensor<2xi32> + %cst_0 = constant dense<1> : tensor<2xi32> + %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32> + return %0 : tensor<21x15x2xf32> + + // CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32> + // CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32> + // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32> +} + +// CHECK-LABEL: @StridedSliceEllipsisMaskAfter +func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7xf32> { + %cst = constant dense<0> : tensor<2xi32> + %cst_0 = constant dense<1> : tensor<2xi32> + %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x15x7xf32> + return %0 : tensor<5x15x7xf32> + + // CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32> + // CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32> + // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32> +} + // CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> { %cst = constant dense<0> : tensor<4xi32> diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 3419ee22174..920570ae43c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -429,13 +429,9 @@ struct ConvertTFStridedSlice : public RewritePattern { explicit ConvertTFStridedSlice(MLIRContext *context) : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - // TODO(renjieliu): Consider expand the transformation for ellipsis & shrink - // mask as well. + PatternMatchResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask, + PatternRewriter &rewriter) const { TF::StridedSliceOp strided_slice_op = llvm::cast(op); - uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue(); - if (new_axis_mask == 0) return matchFailure(); // Insert a new reshape op. Value original_input = strided_slice_op.input(); @@ -491,6 +487,138 @@ struct ConvertTFStridedSlice : public RewritePattern { strided_slice_op.shrink_axis_mask())); return matchSuccess(); } + + PatternMatchResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask, + PatternRewriter &rewriter) const { + TF::StridedSliceOp strided_slice_op = llvm::cast(op); + + DenseIntElementsAttr begin_dense_elem_attr; + Value begin = strided_slice_op.begin(); + auto begin_ranked_attr_type = begin.getType().dyn_cast(); + if (!begin_ranked_attr_type || + !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) { + return matchFailure(); + } + + DenseIntElementsAttr end_dense_elem_attr; + Value end = strided_slice_op.end(); + auto end_ranked_attr_type = end.getType().dyn_cast(); + if (!end_ranked_attr_type || + !matchPattern(end, m_Constant(&end_dense_elem_attr))) { + return matchFailure(); + } + + DenseIntElementsAttr stride_dense_elem_attr; + Value stride = strided_slice_op.strides(); + auto stride_ranked_attr_type = + stride.getType().dyn_cast(); + if (!stride_ranked_attr_type || + !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) { + return matchFailure(); + } + + Value input = strided_slice_op.input(); + RankedTensorType input_type = input.getType().cast(); + const ArrayRef input_shape = input_type.getShape(); + + const int input_size = input_shape.size(); + + RankedTensorType begin_type = begin.getType().cast(); + const ArrayRef begin_shape = begin_type.getShape(); + const int begin_dim = begin_shape.size(); + + if (begin_dim != 1) return matchFailure(); + + const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1; + + llvm::APInt new_begin_mask = strided_slice_op.begin_mask(); + llvm::APInt new_end_mask = strided_slice_op.end_mask(); + + SmallVector padded_begin; + SmallVector padded_end; + SmallVector padded_stride; + + // Before the ellipsis. + uint64_t index = 1; + int count = 0; + + while (index < ellipsis_mask) { + padded_begin.push_back(begin_dense_elem_attr.getValue(count)); + padded_end.push_back(end_dense_elem_attr.getValue(count)); + padded_stride.push_back(stride_dense_elem_attr.getValue(count)); + index <<= 1; + count++; + } + + // Ellipsis. + for (int i = 0; i < ellipsis_filled_dim_size; ++i) { + new_begin_mask |= ellipsis_mask; + new_end_mask |= ellipsis_mask; + + // Mimic the begin/end/strides mask behavior. + padded_begin.push_back(0); + padded_end.push_back(0); + padded_stride.push_back(1); + + ellipsis_mask <<= 1; + } + + // Account for ellipsis mask. + count++; + + // After the ellipsis. + for (; count < begin_shape[0]; ++count) { + padded_begin.push_back(begin_dense_elem_attr.getValue(count)); + padded_end.push_back(end_dense_elem_attr.getValue(count)); + padded_stride.push_back(stride_dense_elem_attr.getValue(count)); + } + + auto attribute_type = rewriter.getIntegerType(64); + + int full_dim_count = padded_begin.size(); + auto type = + RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32)); + + auto begin_attr = DenseElementsAttr::get(type, padded_begin); + auto begin_op = rewriter.create(op->getLoc(), type, begin_attr); + auto end_attr = DenseElementsAttr::get(type, padded_end); + auto end_op = rewriter.create(op->getLoc(), type, end_attr); + auto stride_attr = DenseElementsAttr::get(type, padded_stride); + auto stride_op = + rewriter.create(op->getLoc(), type, stride_attr); + + rewriter.replaceOpWithNewOp( + op, strided_slice_op.getType(), input, begin_op.getResult(), + end_op.getResult(), stride_op.getResult(), + rewriter.getIntegerAttr(attribute_type, new_begin_mask), + rewriter.getIntegerAttr(attribute_type, new_end_mask), + rewriter.getI64IntegerAttr(0), + rewriter.getIntegerAttr(attribute_type, + strided_slice_op.new_axis_mask()), + rewriter.getIntegerAttr(attribute_type, + strided_slice_op.shrink_axis_mask())); + return matchSuccess(); + } + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // TODO(renjieliu): Consider expand the transformation for shrink + // mask as well. + TF::StridedSliceOp strided_slice_op = llvm::cast(op); + + // Handle new axis mask. + uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue(); + if (new_axis_mask != 0) { + return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter); + } + + // Handle ellipsis mask. + uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask().getZExtValue(); + if (ellipsis_mask != 0) { + return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter); + } + return matchFailure(); + } }; #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" diff --git a/tensorflow/lite/testing/op_tests/strided_slice_np_style.py b/tensorflow/lite/testing/op_tests/strided_slice_np_style.py index d046c5c6c01..95f7acabdf7 100644 --- a/tensorflow/lite/testing/op_tests/strided_slice_np_style.py +++ b/tensorflow/lite/testing/op_tests/strided_slice_np_style.py @@ -50,11 +50,23 @@ def make_strided_slice_np_style_tests(options): "shape": [[21, 15, 7]], "spec": [[slice(3, 7, 2), slice(None), 2]], }, - # Ellipsis. + # Ellipsis 3d. { "dtype": [tf.float32], "shape": [[21, 15, 7]], - "spec": [[slice(3, 7, 2), Ellipsis]], + "spec": [[slice(3, 7, 2), Ellipsis], [Ellipsis, + slice(3, 7, 2)], + [slice(1, 11, 3), Ellipsis, + slice(3, 7, 2)]], + }, + # Ellipsis 4d. + { + "dtype": [tf.float32], + "shape": [[21, 15, 7, 9]], + "spec": [[slice(3, 7, 2), Ellipsis], [Ellipsis, + slice(3, 7, 2)], + [slice(1, 11, 3), Ellipsis, + slice(3, 7, 2)]], }, # All combinations. {