Add ellipsis mask transformation in mlir for tflite.

PiperOrigin-RevId: 294820182
Change-Id: I7dbc5ee793e5e2a65a820847b4b0be4c73090b50
This commit is contained in:
Renjie Liu 2020-02-12 20:02:40 -08:00 committed by TensorFlower Gardener
parent 7fa9732b8d
commit f6a8bcd0f2
3 changed files with 172 additions and 8 deletions

View File

@ -453,6 +453,30 @@ func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
// CHECK: return %arg0 : tensor<3xf32> // CHECK: return %arg0 : tensor<3xf32>
} }
// CHECK-LABEL: @StridedSliceEllipsisMaskBefore
func @StridedSliceEllipsisMaskBefore(%arg0: tensor<21x15x7xf32>) -> tensor<21x15x2xf32> {
%cst = constant dense<0> : tensor<2xi32>
%cst_0 = constant dense<1> : tensor<2xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
return %0 : tensor<21x15x2xf32>
// CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
// CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
// CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32>
}
// CHECK-LABEL: @StridedSliceEllipsisMaskAfter
func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7xf32> {
%cst = constant dense<0> : tensor<2xi32>
%cst_0 = constant dense<1> : tensor<2xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x15x7xf32>
return %0 : tensor<5x15x7xf32>
// CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
// CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
// CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32>
}
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask // CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> { func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
%cst = constant dense<0> : tensor<4xi32> %cst = constant dense<0> : tensor<4xi32>

View File

@ -429,13 +429,9 @@ struct ConvertTFStridedSlice : public RewritePattern {
explicit ConvertTFStridedSlice(MLIRContext *context) explicit ConvertTFStridedSlice(MLIRContext *context)
: RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {} : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
PatternMatchResult matchAndRewrite(Operation *op, PatternMatchResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const {
// TODO(renjieliu): Consider expand the transformation for ellipsis & shrink
// mask as well.
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op); TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
if (new_axis_mask == 0) return matchFailure();
// Insert a new reshape op. // Insert a new reshape op.
Value original_input = strided_slice_op.input(); Value original_input = strided_slice_op.input();
@ -491,6 +487,138 @@ struct ConvertTFStridedSlice : public RewritePattern {
strided_slice_op.shrink_axis_mask())); strided_slice_op.shrink_axis_mask()));
return matchSuccess(); return matchSuccess();
} }
PatternMatchResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
PatternRewriter &rewriter) const {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
DenseIntElementsAttr begin_dense_elem_attr;
Value begin = strided_slice_op.begin();
auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
if (!begin_ranked_attr_type ||
!matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
return matchFailure();
}
DenseIntElementsAttr end_dense_elem_attr;
Value end = strided_slice_op.end();
auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
if (!end_ranked_attr_type ||
!matchPattern(end, m_Constant(&end_dense_elem_attr))) {
return matchFailure();
}
DenseIntElementsAttr stride_dense_elem_attr;
Value stride = strided_slice_op.strides();
auto stride_ranked_attr_type =
stride.getType().dyn_cast<RankedTensorType>();
if (!stride_ranked_attr_type ||
!matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
return matchFailure();
}
Value input = strided_slice_op.input();
RankedTensorType input_type = input.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> input_shape = input_type.getShape();
const int input_size = input_shape.size();
RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> begin_shape = begin_type.getShape();
const int begin_dim = begin_shape.size();
if (begin_dim != 1) return matchFailure();
const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
llvm::APInt new_end_mask = strided_slice_op.end_mask();
SmallVector<int32_t, 4> padded_begin;
SmallVector<int32_t, 4> padded_end;
SmallVector<int32_t, 4> padded_stride;
// Before the ellipsis.
uint64_t index = 1;
int count = 0;
while (index < ellipsis_mask) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
index <<= 1;
count++;
}
// Ellipsis.
for (int i = 0; i < ellipsis_filled_dim_size; ++i) {
new_begin_mask |= ellipsis_mask;
new_end_mask |= ellipsis_mask;
// Mimic the begin/end/strides mask behavior.
padded_begin.push_back(0);
padded_end.push_back(0);
padded_stride.push_back(1);
ellipsis_mask <<= 1;
}
// Account for ellipsis mask.
count++;
// After the ellipsis.
for (; count < begin_shape[0]; ++count) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
}
auto attribute_type = rewriter.getIntegerType(64);
int full_dim_count = padded_begin.size();
auto type =
RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
auto begin_op = rewriter.create<ConstantOp>(op->getLoc(), type, begin_attr);
auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
auto end_op = rewriter.create<ConstantOp>(op->getLoc(), type, end_attr);
auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
auto stride_op =
rewriter.create<ConstantOp>(op->getLoc(), type, stride_attr);
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
op, strided_slice_op.getType(), input, begin_op.getResult(),
end_op.getResult(), stride_op.getResult(),
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
rewriter.getIntegerAttr(attribute_type, new_end_mask),
rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.shrink_axis_mask()));
return matchSuccess();
}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// TODO(renjieliu): Consider expand the transformation for shrink
// mask as well.
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
// Handle new axis mask.
uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
if (new_axis_mask != 0) {
return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter);
}
// Handle ellipsis mask.
uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask().getZExtValue();
if (ellipsis_mask != 0) {
return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
}
return matchFailure();
}
}; };
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"

View File

@ -50,11 +50,23 @@ def make_strided_slice_np_style_tests(options):
"shape": [[21, 15, 7]], "shape": [[21, 15, 7]],
"spec": [[slice(3, 7, 2), slice(None), 2]], "spec": [[slice(3, 7, 2), slice(None), 2]],
}, },
# Ellipsis. # Ellipsis 3d.
{ {
"dtype": [tf.float32], "dtype": [tf.float32],
"shape": [[21, 15, 7]], "shape": [[21, 15, 7]],
"spec": [[slice(3, 7, 2), Ellipsis]], "spec": [[slice(3, 7, 2), Ellipsis], [Ellipsis,
slice(3, 7, 2)],
[slice(1, 11, 3), Ellipsis,
slice(3, 7, 2)]],
},
# Ellipsis 4d.
{
"dtype": [tf.float32],
"shape": [[21, 15, 7, 9]],
"spec": [[slice(3, 7, 2), Ellipsis], [Ellipsis,
slice(3, 7, 2)],
[slice(1, 11, 3), Ellipsis,
slice(3, 7, 2)]],
}, },
# All combinations. # All combinations.
{ {