Add ellipsis mask transformation in mlir for tflite.
PiperOrigin-RevId: 294820182 Change-Id: I7dbc5ee793e5e2a65a820847b4b0be4c73090b50
This commit is contained in:
parent
7fa9732b8d
commit
f6a8bcd0f2
@ -453,6 +453,30 @@ func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|||||||
// CHECK: return %arg0 : tensor<3xf32>
|
// CHECK: return %arg0 : tensor<3xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @StridedSliceEllipsisMaskBefore
|
||||||
|
func @StridedSliceEllipsisMaskBefore(%arg0: tensor<21x15x7xf32>) -> tensor<21x15x2xf32> {
|
||||||
|
%cst = constant dense<0> : tensor<2xi32>
|
||||||
|
%cst_0 = constant dense<1> : tensor<2xi32>
|
||||||
|
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
|
||||||
|
return %0 : tensor<21x15x2xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
|
||||||
|
// CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
|
||||||
|
// CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @StridedSliceEllipsisMaskAfter
|
||||||
|
func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7xf32> {
|
||||||
|
%cst = constant dense<0> : tensor<2xi32>
|
||||||
|
%cst_0 = constant dense<1> : tensor<2xi32>
|
||||||
|
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x15x7xf32>
|
||||||
|
return %0 : tensor<5x15x7xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
|
||||||
|
// CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
|
||||||
|
// CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
|
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
|
||||||
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
|
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
|
||||||
%cst = constant dense<0> : tensor<4xi32>
|
%cst = constant dense<0> : tensor<4xi32>
|
||||||
|
@ -429,13 +429,9 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
|||||||
explicit ConvertTFStridedSlice(MLIRContext *context)
|
explicit ConvertTFStridedSlice(MLIRContext *context)
|
||||||
: RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
|
: RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
PatternMatchResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const {
|
||||||
// TODO(renjieliu): Consider expand the transformation for ellipsis & shrink
|
|
||||||
// mask as well.
|
|
||||||
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
|
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
|
||||||
uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
|
|
||||||
if (new_axis_mask == 0) return matchFailure();
|
|
||||||
|
|
||||||
// Insert a new reshape op.
|
// Insert a new reshape op.
|
||||||
Value original_input = strided_slice_op.input();
|
Value original_input = strided_slice_op.input();
|
||||||
@ -491,6 +487,138 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
|||||||
strided_slice_op.shrink_axis_mask()));
|
strided_slice_op.shrink_axis_mask()));
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PatternMatchResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
|
||||||
|
PatternRewriter &rewriter) const {
|
||||||
|
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
|
||||||
|
|
||||||
|
DenseIntElementsAttr begin_dense_elem_attr;
|
||||||
|
Value begin = strided_slice_op.begin();
|
||||||
|
auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!begin_ranked_attr_type ||
|
||||||
|
!matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseIntElementsAttr end_dense_elem_attr;
|
||||||
|
Value end = strided_slice_op.end();
|
||||||
|
auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!end_ranked_attr_type ||
|
||||||
|
!matchPattern(end, m_Constant(&end_dense_elem_attr))) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseIntElementsAttr stride_dense_elem_attr;
|
||||||
|
Value stride = strided_slice_op.strides();
|
||||||
|
auto stride_ranked_attr_type =
|
||||||
|
stride.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!stride_ranked_attr_type ||
|
||||||
|
!matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value input = strided_slice_op.input();
|
||||||
|
RankedTensorType input_type = input.getType().cast<RankedTensorType>();
|
||||||
|
const ArrayRef<int64_t> input_shape = input_type.getShape();
|
||||||
|
|
||||||
|
const int input_size = input_shape.size();
|
||||||
|
|
||||||
|
RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
|
||||||
|
const ArrayRef<int64_t> begin_shape = begin_type.getShape();
|
||||||
|
const int begin_dim = begin_shape.size();
|
||||||
|
|
||||||
|
if (begin_dim != 1) return matchFailure();
|
||||||
|
|
||||||
|
const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
|
||||||
|
|
||||||
|
llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
|
||||||
|
llvm::APInt new_end_mask = strided_slice_op.end_mask();
|
||||||
|
|
||||||
|
SmallVector<int32_t, 4> padded_begin;
|
||||||
|
SmallVector<int32_t, 4> padded_end;
|
||||||
|
SmallVector<int32_t, 4> padded_stride;
|
||||||
|
|
||||||
|
// Before the ellipsis.
|
||||||
|
uint64_t index = 1;
|
||||||
|
int count = 0;
|
||||||
|
|
||||||
|
while (index < ellipsis_mask) {
|
||||||
|
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
|
||||||
|
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
|
||||||
|
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
|
||||||
|
index <<= 1;
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ellipsis.
|
||||||
|
for (int i = 0; i < ellipsis_filled_dim_size; ++i) {
|
||||||
|
new_begin_mask |= ellipsis_mask;
|
||||||
|
new_end_mask |= ellipsis_mask;
|
||||||
|
|
||||||
|
// Mimic the begin/end/strides mask behavior.
|
||||||
|
padded_begin.push_back(0);
|
||||||
|
padded_end.push_back(0);
|
||||||
|
padded_stride.push_back(1);
|
||||||
|
|
||||||
|
ellipsis_mask <<= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Account for ellipsis mask.
|
||||||
|
count++;
|
||||||
|
|
||||||
|
// After the ellipsis.
|
||||||
|
for (; count < begin_shape[0]; ++count) {
|
||||||
|
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
|
||||||
|
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
|
||||||
|
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto attribute_type = rewriter.getIntegerType(64);
|
||||||
|
|
||||||
|
int full_dim_count = padded_begin.size();
|
||||||
|
auto type =
|
||||||
|
RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
|
||||||
|
|
||||||
|
auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
|
||||||
|
auto begin_op = rewriter.create<ConstantOp>(op->getLoc(), type, begin_attr);
|
||||||
|
auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
|
||||||
|
auto end_op = rewriter.create<ConstantOp>(op->getLoc(), type, end_attr);
|
||||||
|
auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
|
||||||
|
auto stride_op =
|
||||||
|
rewriter.create<ConstantOp>(op->getLoc(), type, stride_attr);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
|
||||||
|
op, strided_slice_op.getType(), input, begin_op.getResult(),
|
||||||
|
end_op.getResult(), stride_op.getResult(),
|
||||||
|
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
|
||||||
|
rewriter.getIntegerAttr(attribute_type, new_end_mask),
|
||||||
|
rewriter.getI64IntegerAttr(0),
|
||||||
|
rewriter.getIntegerAttr(attribute_type,
|
||||||
|
strided_slice_op.new_axis_mask()),
|
||||||
|
rewriter.getIntegerAttr(attribute_type,
|
||||||
|
strided_slice_op.shrink_axis_mask()));
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// TODO(renjieliu): Consider expand the transformation for shrink
|
||||||
|
// mask as well.
|
||||||
|
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
|
||||||
|
|
||||||
|
// Handle new axis mask.
|
||||||
|
uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
|
||||||
|
if (new_axis_mask != 0) {
|
||||||
|
return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle ellipsis mask.
|
||||||
|
uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask().getZExtValue();
|
||||||
|
if (ellipsis_mask != 0) {
|
||||||
|
return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
|
||||||
|
}
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
|
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
|
||||||
|
@ -50,11 +50,23 @@ def make_strided_slice_np_style_tests(options):
|
|||||||
"shape": [[21, 15, 7]],
|
"shape": [[21, 15, 7]],
|
||||||
"spec": [[slice(3, 7, 2), slice(None), 2]],
|
"spec": [[slice(3, 7, 2), slice(None), 2]],
|
||||||
},
|
},
|
||||||
# Ellipsis.
|
# Ellipsis 3d.
|
||||||
{
|
{
|
||||||
"dtype": [tf.float32],
|
"dtype": [tf.float32],
|
||||||
"shape": [[21, 15, 7]],
|
"shape": [[21, 15, 7]],
|
||||||
"spec": [[slice(3, 7, 2), Ellipsis]],
|
"spec": [[slice(3, 7, 2), Ellipsis], [Ellipsis,
|
||||||
|
slice(3, 7, 2)],
|
||||||
|
[slice(1, 11, 3), Ellipsis,
|
||||||
|
slice(3, 7, 2)]],
|
||||||
|
},
|
||||||
|
# Ellipsis 4d.
|
||||||
|
{
|
||||||
|
"dtype": [tf.float32],
|
||||||
|
"shape": [[21, 15, 7, 9]],
|
||||||
|
"spec": [[slice(3, 7, 2), Ellipsis], [Ellipsis,
|
||||||
|
slice(3, 7, 2)],
|
||||||
|
[slice(1, 11, 3), Ellipsis,
|
||||||
|
slice(3, 7, 2)]],
|
||||||
},
|
},
|
||||||
# All combinations.
|
# All combinations.
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user