Add a pass in mlir optimization for strided_slice.

Currently the TFLite kernel does not support strided_slice with a non-zero new_axis_mask. We can add a reshape op right before the strided_slice; it should always be safe to do so.

PiperOrigin-RevId: 258891977
This commit is contained in:
Renjie Liu 2019-07-18 19:28:42 -07:00 committed by TensorFlower Gardener
parent 0939d54145
commit a355cfad30
2 changed files with 115 additions and 0 deletions

View File

@ -147,3 +147,32 @@ func @FuseFullyConnectedAdd(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>)
// CHECK: %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32>
// CHECK: %2 = "tfl.fully_connected"(%0, %1, %cst)
}
// Tests that a strided_slice whose new_axis_mask is 0 is left untouched:
// no reshape is inserted and the op's attributes are unchanged.
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
%cst = constant dense<0> : tensor<4xi32>
%cst_0 = constant dense<1> : tensor<4xi32>
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32>
%1 = "tfl.strided_slice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
return %1 : tensor<1x2x3x1xf32>
// CHECK: %cst = constant dense<0> : tensor<4xi32>
// CHECK: %cst_0 = constant dense<1> : tensor<4xi32>
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32>
// CHECK: %1 = "tfl.strided_slice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}
// Tests that a strided_slice with new_axis_mask = 9 (binary [1 0 0 1], i.e.
// new axes at dims 0 and 3) gets a reshape inserted that pads the 2x3 input
// to 1x2x3x1. The new-axis bits are folded into begin_mask/end_mask
// (6 | 9 = 15) and new_axis_mask on the rewritten op becomes 0.
// CHECK-LABEL: @PadStridedSliceNewAxisMask
func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
%cst = constant dense<0> : tensor<4xi32>
%cst_0 = constant dense<1> : tensor<4xi32>
%0 = "tfl.pseudo_input"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tfl.strided_slice"(%0, %cst, %cst, %cst_0) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 9 : i32, shrink_axis_mask = 0 : i32} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
return %1 : tensor<1x2x3x1xf32>
// CHECK: %cst = constant dense<0> : tensor<4xi32>
// CHECK: %cst_0 = constant dense<1> : tensor<4xi32>
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK: %1 = "tfl.reshape"(%0) : (tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
// CHECK: %2 = "tfl.strided_slice"(%1, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}

View File

@ -17,8 +17,12 @@ limitations under the License.
// optimizes them to resulting operations in TensorFlowLite dialect.
#include <climits>
#include <cstdint>
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
@ -117,6 +121,87 @@ struct FuseFullyConnectedAndAdd : public RewritePattern {
}
};
// StridedSlice can have complicated attributes like begin_axis_mask,
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
// masks will complicate the strided_slice computation logic; we can simplify
// the logic by inserting a reshape op to pad the inputs so strided_slice can
// be easier to handle.
//
// So the graph may look like below:
//   original_input -> strided_slice -> output
//      (transforms into)
//   original_input -> reshape -> strided_slice -> output
//
// And the new shape is computed based on the masks.
//
// An example for new_axis_mask: say the new_axis_mask is 9, which represents
// [1 0 0 1], and that means we're inserting two new axes at dims 0 & 3, so
// if the original shape is [2, 3], we reshape it into [1, 2, 3, 1].
struct PadStridedSliceDims : public RewritePattern {
  explicit PadStridedSliceDims(MLIRContext *context)
      : RewritePattern(TFL::StridedSliceOp::getOperationName(),
                       // This pattern generates a reshape followed by a
                       // strided_slice. (The previous list named
                       // "tfl.strided_slice" twice and omitted the reshape.)
                       {"tfl.reshape", "tfl.strided_slice"}, 2, context) {}

  PatternMatchResult matchAndRewrite(Operation *strided_slice_op,
                                     PatternRewriter &rewriter) const override {
    // TODO(renjieliu): Consider expanding the transformation for ellipsis &
    // shrink mask as well.
    TFL::StridedSliceOp strided_slice =
        llvm::cast<TFL::StridedSliceOp>(strided_slice_op);
    const uint64_t new_axis_mask = strided_slice.new_axis_mask().getZExtValue();
    // Only ops that actually request new axes need the reshape padding.
    if (new_axis_mask == 0) return matchFailure();

    // NOTE(review): when ellipsis_mask is non-zero the new-axis bit positions
    // are interpreted relative to the expanded ellipsis, so the shape
    // computation below would be wrong. Conservatively bail out until that
    // case is supported (see TODO above).
    if (strided_slice.ellipsis_mask().getZExtValue() != 0) {
      return matchFailure();
    }

    // Insert a new reshape op that materializes the size-1 dims requested by
    // new_axis_mask, so the strided_slice no longer needs the mask.
    Value *original_input = strided_slice.input();
    const RankedTensorType original_input_type =
        original_input->getType().template cast<RankedTensorType>();
    const ArrayRef<int64_t> original_input_shape =
        original_input_type.getShape();
    const RankedTensorType begin_type =
        strided_slice.begin()->getType().template cast<RankedTensorType>();
    // The begin tensor has one entry per output dim, i.e. the padded rank.
    const int dim_size = begin_type.getShape()[0];
    SmallVector<int64_t, 4> new_shape;
    int mask = 1;
    int index = 0;
    for (int i = 0; i < dim_size; ++i) {
      if (mask & new_axis_mask) {
        // Bit i is set: insert a new size-1 axis at dim i.
        new_shape.emplace_back(1);
      } else {
        // Bit i is clear: keep the next original input dim.
        new_shape.emplace_back(original_input_shape[index]);
        ++index;
      }
      mask = mask << 1;
    }

    auto new_output_type =
        rewriter.getTensorType(new_shape, original_input_type.getElementType());

    TFL::ReshapeOp reshape = rewriter.create<TFL::ReshapeOp>(
        strided_slice.getLoc(), new_output_type, original_input);

    // Replace the original strided_slice. The inserted axes now physically
    // exist in the reshaped input, so new_axis_mask becomes 0 and the
    // corresponding bits must be set in begin_mask & end_mask so the full
    // extent of each inserted size-1 dim is taken.
    llvm::APInt new_begin_mask = strided_slice.begin_mask();
    llvm::APInt new_end_mask = strided_slice.end_mask();
    new_begin_mask |= strided_slice.new_axis_mask();
    new_end_mask |= strided_slice.new_axis_mask();

    auto attribute_type = rewriter.getIntegerType(32);
    rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
        strided_slice_op, strided_slice.getType(), reshape,
        strided_slice.begin(), strided_slice.end(), strided_slice.strides(),
        rewriter.getIntegerAttr(attribute_type, new_begin_mask),
        rewriter.getIntegerAttr(attribute_type, new_end_mask),
        rewriter.getIntegerAttr(attribute_type, strided_slice.ellipsis_mask()),
        rewriter.getI32IntegerAttr(0),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice.shrink_axis_mask()));
    return matchSuccess();
  }
};
void Optimize::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();
@ -124,6 +209,7 @@ void Optimize::runOnFunction() {
  TFL::populateWithGenerated(&getContext(), &patterns);
  patterns.push_back(
      llvm::make_unique<FuseFullyConnectedAndAdd>(&getContext()));
  // Registers the reshape-padding pattern for strided_slice with new_axis_mask.
  patterns.push_back(llvm::make_unique<PadStridedSliceDims>(&getContext()));
  applyPatternsGreedily(func, std::move(patterns));
}