Add ellipsis mask transformation in mlir for tflite.
PiperOrigin-RevId: 294820182 Change-Id: I7dbc5ee793e5e2a65a820847b4b0be4c73090b50
This commit is contained in:
		
							parent
							
								
									7fa9732b8d
								
							
						
					
					
						commit
						f6a8bcd0f2
					
				@ -453,6 +453,30 @@ func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
 | 
				
			|||||||
  // CHECK:  return %arg0 : tensor<3xf32>
 | 
					  // CHECK:  return %arg0 : tensor<3xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @StridedSliceEllipsisMaskBefore
 | 
				
			||||||
 | 
					func @StridedSliceEllipsisMaskBefore(%arg0: tensor<21x15x7xf32>) -> tensor<21x15x2xf32> {
 | 
				
			||||||
 | 
					  %cst = constant dense<0> : tensor<2xi32>
 | 
				
			||||||
 | 
					  %cst_0 = constant dense<1> : tensor<2xi32>
 | 
				
			||||||
 | 
					  %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<21x15x2xf32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
 | 
				
			||||||
 | 
					  // CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
 | 
				
			||||||
 | 
					  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @StridedSliceEllipsisMaskAfter
 | 
				
			||||||
 | 
					func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7xf32> {
 | 
				
			||||||
 | 
					  %cst = constant dense<0> : tensor<2xi32>
 | 
				
			||||||
 | 
					  %cst_0 = constant dense<1> : tensor<2xi32>
 | 
				
			||||||
 | 
					  %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x15x7xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<5x15x7xf32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
 | 
				
			||||||
 | 
					  // CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
 | 
				
			||||||
 | 
					  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
 | 
					// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
 | 
				
			||||||
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
 | 
					func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
 | 
				
			||||||
  %cst = constant dense<0> : tensor<4xi32>
 | 
					  %cst = constant dense<0> : tensor<4xi32>
 | 
				
			||||||
 | 
				
			|||||||
@ -429,13 +429,9 @@ struct ConvertTFStridedSlice : public RewritePattern {
 | 
				
			|||||||
  explicit ConvertTFStridedSlice(MLIRContext *context)
 | 
					  explicit ConvertTFStridedSlice(MLIRContext *context)
 | 
				
			||||||
      : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
 | 
					      : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  PatternMatchResult matchAndRewrite(Operation *op,
 | 
					  PatternMatchResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
 | 
				
			||||||
                                     PatternRewriter &rewriter) const override {
 | 
					                                        PatternRewriter &rewriter) const {
 | 
				
			||||||
    // TODO(renjieliu): Consider expand the transformation for ellipsis & shrink
 | 
					 | 
				
			||||||
    // mask as well.
 | 
					 | 
				
			||||||
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
 | 
					    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
 | 
				
			||||||
    uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
 | 
					 | 
				
			||||||
    if (new_axis_mask == 0) return matchFailure();
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Insert a new reshape op.
 | 
					    // Insert a new reshape op.
 | 
				
			||||||
    Value original_input = strided_slice_op.input();
 | 
					    Value original_input = strided_slice_op.input();
 | 
				
			||||||
@ -491,6 +487,138 @@ struct ConvertTFStridedSlice : public RewritePattern {
 | 
				
			|||||||
                                strided_slice_op.shrink_axis_mask()));
 | 
					                                strided_slice_op.shrink_axis_mask()));
 | 
				
			||||||
    return matchSuccess();
 | 
					    return matchSuccess();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
 | 
				
			||||||
 | 
					                                         PatternRewriter &rewriter) const {
 | 
				
			||||||
 | 
					    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    DenseIntElementsAttr begin_dense_elem_attr;
 | 
				
			||||||
 | 
					    Value begin = strided_slice_op.begin();
 | 
				
			||||||
 | 
					    auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
 | 
				
			||||||
 | 
					    if (!begin_ranked_attr_type ||
 | 
				
			||||||
 | 
					        !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
 | 
				
			||||||
 | 
					      return matchFailure();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    DenseIntElementsAttr end_dense_elem_attr;
 | 
				
			||||||
 | 
					    Value end = strided_slice_op.end();
 | 
				
			||||||
 | 
					    auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
 | 
				
			||||||
 | 
					    if (!end_ranked_attr_type ||
 | 
				
			||||||
 | 
					        !matchPattern(end, m_Constant(&end_dense_elem_attr))) {
 | 
				
			||||||
 | 
					      return matchFailure();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    DenseIntElementsAttr stride_dense_elem_attr;
 | 
				
			||||||
 | 
					    Value stride = strided_slice_op.strides();
 | 
				
			||||||
 | 
					    auto stride_ranked_attr_type =
 | 
				
			||||||
 | 
					        stride.getType().dyn_cast<RankedTensorType>();
 | 
				
			||||||
 | 
					    if (!stride_ranked_attr_type ||
 | 
				
			||||||
 | 
					        !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
 | 
				
			||||||
 | 
					      return matchFailure();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Value input = strided_slice_op.input();
 | 
				
			||||||
 | 
					    RankedTensorType input_type = input.getType().cast<RankedTensorType>();
 | 
				
			||||||
 | 
					    const ArrayRef<int64_t> input_shape = input_type.getShape();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const int input_size = input_shape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
 | 
				
			||||||
 | 
					    const ArrayRef<int64_t> begin_shape = begin_type.getShape();
 | 
				
			||||||
 | 
					    const int begin_dim = begin_shape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (begin_dim != 1) return matchFailure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
 | 
				
			||||||
 | 
					    llvm::APInt new_end_mask = strided_slice_op.end_mask();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SmallVector<int32_t, 4> padded_begin;
 | 
				
			||||||
 | 
					    SmallVector<int32_t, 4> padded_end;
 | 
				
			||||||
 | 
					    SmallVector<int32_t, 4> padded_stride;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Before the ellipsis.
 | 
				
			||||||
 | 
					    uint64_t index = 1;
 | 
				
			||||||
 | 
					    int count = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    while (index < ellipsis_mask) {
 | 
				
			||||||
 | 
					      padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
 | 
				
			||||||
 | 
					      padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
 | 
				
			||||||
 | 
					      padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
 | 
				
			||||||
 | 
					      index <<= 1;
 | 
				
			||||||
 | 
					      count++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Ellipsis.
 | 
				
			||||||
 | 
					    for (int i = 0; i < ellipsis_filled_dim_size; ++i) {
 | 
				
			||||||
 | 
					      new_begin_mask |= ellipsis_mask;
 | 
				
			||||||
 | 
					      new_end_mask |= ellipsis_mask;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Mimic the begin/end/strides mask behavior.
 | 
				
			||||||
 | 
					      padded_begin.push_back(0);
 | 
				
			||||||
 | 
					      padded_end.push_back(0);
 | 
				
			||||||
 | 
					      padded_stride.push_back(1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      ellipsis_mask <<= 1;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Account for ellipsis mask.
 | 
				
			||||||
 | 
					    count++;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // After the ellipsis.
 | 
				
			||||||
 | 
					    for (; count < begin_shape[0]; ++count) {
 | 
				
			||||||
 | 
					      padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
 | 
				
			||||||
 | 
					      padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
 | 
				
			||||||
 | 
					      padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto attribute_type = rewriter.getIntegerType(64);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int full_dim_count = padded_begin.size();
 | 
				
			||||||
 | 
					    auto type =
 | 
				
			||||||
 | 
					        RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
 | 
				
			||||||
 | 
					    auto begin_op = rewriter.create<ConstantOp>(op->getLoc(), type, begin_attr);
 | 
				
			||||||
 | 
					    auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
 | 
				
			||||||
 | 
					    auto end_op = rewriter.create<ConstantOp>(op->getLoc(), type, end_attr);
 | 
				
			||||||
 | 
					    auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
 | 
				
			||||||
 | 
					    auto stride_op =
 | 
				
			||||||
 | 
					        rewriter.create<ConstantOp>(op->getLoc(), type, stride_attr);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
 | 
				
			||||||
 | 
					        op, strided_slice_op.getType(), input, begin_op.getResult(),
 | 
				
			||||||
 | 
					        end_op.getResult(), stride_op.getResult(),
 | 
				
			||||||
 | 
					        rewriter.getIntegerAttr(attribute_type, new_begin_mask),
 | 
				
			||||||
 | 
					        rewriter.getIntegerAttr(attribute_type, new_end_mask),
 | 
				
			||||||
 | 
					        rewriter.getI64IntegerAttr(0),
 | 
				
			||||||
 | 
					        rewriter.getIntegerAttr(attribute_type,
 | 
				
			||||||
 | 
					                                strided_slice_op.new_axis_mask()),
 | 
				
			||||||
 | 
					        rewriter.getIntegerAttr(attribute_type,
 | 
				
			||||||
 | 
					                                strided_slice_op.shrink_axis_mask()));
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult matchAndRewrite(Operation *op,
 | 
				
			||||||
 | 
					                                     PatternRewriter &rewriter) const override {
 | 
				
			||||||
 | 
					    // TODO(renjieliu): Consider expand the transformation for shrink
 | 
				
			||||||
 | 
					    // mask as well.
 | 
				
			||||||
 | 
					    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Handle new axis mask.
 | 
				
			||||||
 | 
					    uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
 | 
				
			||||||
 | 
					    if (new_axis_mask != 0) {
 | 
				
			||||||
 | 
					      return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Handle ellipsis mask.
 | 
				
			||||||
 | 
					    uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask().getZExtValue();
 | 
				
			||||||
 | 
					    if (ellipsis_mask != 0) {
 | 
				
			||||||
 | 
					      return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return matchFailure();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
 | 
					#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
 | 
				
			||||||
 | 
				
			|||||||
@ -50,11 +50,23 @@ def make_strided_slice_np_style_tests(options):
 | 
				
			|||||||
          "shape": [[21, 15, 7]],
 | 
					          "shape": [[21, 15, 7]],
 | 
				
			||||||
          "spec": [[slice(3, 7, 2), slice(None), 2]],
 | 
					          "spec": [[slice(3, 7, 2), slice(None), 2]],
 | 
				
			||||||
      },
 | 
					      },
 | 
				
			||||||
      # Ellipsis.
 | 
					      # Ellipsis 3d.
 | 
				
			||||||
      {
 | 
					      {
 | 
				
			||||||
          "dtype": [tf.float32],
 | 
					          "dtype": [tf.float32],
 | 
				
			||||||
          "shape": [[21, 15, 7]],
 | 
					          "shape": [[21, 15, 7]],
 | 
				
			||||||
          "spec": [[slice(3, 7, 2), Ellipsis]],
 | 
					          "spec": [[slice(3, 7, 2), Ellipsis], [Ellipsis,
 | 
				
			||||||
 | 
					                                                slice(3, 7, 2)],
 | 
				
			||||||
 | 
					                   [slice(1, 11, 3), Ellipsis,
 | 
				
			||||||
 | 
					                    slice(3, 7, 2)]],
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      # Ellipsis 4d.
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					          "dtype": [tf.float32],
 | 
				
			||||||
 | 
					          "shape": [[21, 15, 7, 9]],
 | 
				
			||||||
 | 
					          "spec": [[slice(3, 7, 2), Ellipsis], [Ellipsis,
 | 
				
			||||||
 | 
					                                                slice(3, 7, 2)],
 | 
				
			||||||
 | 
					                   [slice(1, 11, 3), Ellipsis,
 | 
				
			||||||
 | 
					                    slice(3, 7, 2)]],
 | 
				
			||||||
      },
 | 
					      },
 | 
				
			||||||
      # All combinations.
 | 
					      # All combinations.
 | 
				
			||||||
      {
 | 
					      {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user